Skip to content

Commit

Permalink
add unittets to test the newly proposed matching based on participant…
Browse files Browse the repository at this point in the history
…_id, acq_id, and run_id
  • Loading branch information
valosekj committed Dec 10, 2024
1 parent 2bc7297 commit 0c61d12
Showing 1 changed file with 95 additions and 1 deletion.
96 changes: 95 additions & 1 deletion test/test_metrics/test_pairwise_measures_neuropoly.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import os
import numpy as np
import nibabel as nib
from compute_metrics_reloaded import compute_metrics_single_subject
from compute_metrics_reloaded import compute_metrics_single_subject, get_images, fetch_participant_id_acq_id_run_id
import tempfile

METRICS = ['dsc', 'fbeta', 'nsd', 'vol_diff', 'rel_vol_error', 'lesion_ppv', 'lesion_sensitivity', 'lesion_f1_score',
Expand Down Expand Up @@ -358,6 +358,100 @@ def test_non_empty_ref_and_pred_with_full_overlap(self):
# Assert metrics
self.assert_metrics(metrics_dict, expected_metrics)

class TestGetImages(unittest.TestCase):
def setUp(self):
"""
Create temporary directories and files for testing.
"""
self.pred_dir = tempfile.TemporaryDirectory()
self.ref_dir = tempfile.TemporaryDirectory()

def tearDown(self):
"""
Cleanup temporary directories and files after tests.
"""
self.pred_dir.cleanup()
self.ref_dir.cleanup()

def create_temp_file(self, directory, filename):
"""
Create a temporary file in the given directory with the specified filename.
"""
file_path = os.path.join(directory, filename)
with open(file_path, 'w') as f:
f.write('dummy content')
return file_path

def test_matching_files(self):
"""
Test matching files based on participant_id, acq_id, and run_id.
"""
self.create_temp_file(self.pred_dir.name, "sub-01_acq-01_run-01_pred.nii.gz")
self.create_temp_file(self.ref_dir.name, "sub-01_acq-01_run-01_ref.nii.gz")

pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name)
self.assertEqual(len(pred_files), 1)
self.assertEqual(len(ref_files), 1)

def test_mismatched_files(self):
"""
Test when no files match based on the criteria.
"""
self.create_temp_file(self.pred_dir.name, "sub-01_acq-01_run-01_pred.nii.gz")
self.create_temp_file(self.ref_dir.name, "sub-02_acq-02_run-02_ref.nii.gz")

pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name)
self.assertEqual(len(pred_files), 0)
self.assertEqual(len(ref_files), 0)

def test_acq_id_empty(self):
"""
Test when acq_id is empty.
"""
self.create_temp_file(self.pred_dir.name, "sub-01_run-01_pred.nii.gz")
self.create_temp_file(self.ref_dir.name, "sub-01_run-01_ref.nii.gz")

pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name)
self.assertEqual(len(pred_files), 1)
self.assertEqual(len(ref_files), 1)

def test_run_id_empty(self):
"""
Test when run_id is empty in the filenames.
"""
self.create_temp_file(self.pred_dir.name, "sub-01_acq-01_pred.nii.gz")
self.create_temp_file(self.ref_dir.name, "sub-01_acq-01_ref.nii.gz")

pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name)

# Assert the matched files
self.assertEqual(len(pred_files), 1)
self.assertEqual(len(ref_files), 1)
self.assertIn("sub-01_acq-01_pred.nii.gz", pred_files[0])
self.assertIn("sub-01_acq-01_ref.nii.gz", ref_files[0])

def test_no_files(self):
"""
Test when there are no files in the directories.
Ensure that FileNotFoundError is raised.
"""
with self.assertRaises(FileNotFoundError) as context:
get_images(self.pred_dir.name, self.ref_dir.name)
# Check the exception message
self.assertIn(f'No prediction files found in {self.pred_dir.name}', str(context.exception))

def test_partial_matching(self):
"""
Test when some files match and some do not.
"""
self.create_temp_file(self.pred_dir.name, "sub-01_acq-01_run-01_pred.nii.gz")
self.create_temp_file(self.ref_dir.name, "sub-01_acq-01_run-01_ref.nii.gz")
self.create_temp_file(self.pred_dir.name, "sub-02_acq-02_run-02_pred.nii.gz")

pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name)
self.assertEqual(len(pred_files), 1)
self.assertEqual(len(ref_files), 1)


if __name__ == '__main__':
unittest.main()

0 comments on commit 0c61d12

Please sign in to comment.