diff --git a/test/test_metrics/test_pairwise_measures_neuropoly.py b/test/test_metrics/test_pairwise_measures_neuropoly.py index 60da2be..4067526 100644 --- a/test/test_metrics/test_pairwise_measures_neuropoly.py +++ b/test/test_metrics/test_pairwise_measures_neuropoly.py @@ -13,7 +13,7 @@ import os import numpy as np import nibabel as nib -from compute_metrics_reloaded import compute_metrics_single_subject +from compute_metrics_reloaded import compute_metrics_single_subject, get_images, fetch_participant_id_acq_id_run_id import tempfile METRICS = ['dsc', 'fbeta', 'nsd', 'vol_diff', 'rel_vol_error', 'lesion_ppv', 'lesion_sensitivity', 'lesion_f1_score', @@ -358,6 +358,100 @@ def test_non_empty_ref_and_pred_with_full_overlap(self): # Assert metrics self.assert_metrics(metrics_dict, expected_metrics) +class TestGetImages(unittest.TestCase): + def setUp(self): + """ + Create temporary directories and files for testing. + """ + self.pred_dir = tempfile.TemporaryDirectory() + self.ref_dir = tempfile.TemporaryDirectory() + + def tearDown(self): + """ + Cleanup temporary directories and files after tests. + """ + self.pred_dir.cleanup() + self.ref_dir.cleanup() + + def create_temp_file(self, directory, filename): + """ + Create a temporary file in the given directory with the specified filename. + """ + file_path = os.path.join(directory, filename) + with open(file_path, 'w') as f: + f.write('dummy content') + return file_path + + def test_matching_files(self): + """ + Test matching files based on participant_id, acq_id, and run_id. + """ + self.create_temp_file(self.pred_dir.name, "sub-01_acq-01_run-01_pred.nii.gz") + self.create_temp_file(self.ref_dir.name, "sub-01_acq-01_run-01_ref.nii.gz") + + pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name) + self.assertEqual(len(pred_files), 1) + self.assertEqual(len(ref_files), 1) + + def test_mismatched_files(self): + """ + Test when no files match based on the criteria. + """ + self.create_temp_file(self.pred_dir.name, "sub-01_acq-01_run-01_pred.nii.gz") + self.create_temp_file(self.ref_dir.name, "sub-02_acq-02_run-02_ref.nii.gz") + + pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name) + self.assertEqual(len(pred_files), 0) + self.assertEqual(len(ref_files), 0) + + def test_acq_id_empty(self): + """ + Test when acq_id is empty. + """ + self.create_temp_file(self.pred_dir.name, "sub-01_run-01_pred.nii.gz") + self.create_temp_file(self.ref_dir.name, "sub-01_run-01_ref.nii.gz") + + pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name) + self.assertEqual(len(pred_files), 1) + self.assertEqual(len(ref_files), 1) + + def test_run_id_empty(self): + """ + Test when run_id is empty in the filenames. + """ + self.create_temp_file(self.pred_dir.name, "sub-01_acq-01_pred.nii.gz") + self.create_temp_file(self.ref_dir.name, "sub-01_acq-01_ref.nii.gz") + + pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name) + + # Assert the matched files + self.assertEqual(len(pred_files), 1) + self.assertEqual(len(ref_files), 1) + self.assertIn("sub-01_acq-01_pred.nii.gz", pred_files[0]) + self.assertIn("sub-01_acq-01_ref.nii.gz", ref_files[0]) + + def test_no_files(self): + """ + Test when there are no files in the directories. + Ensure that FileNotFoundError is raised. + """ + with self.assertRaises(FileNotFoundError) as context: + get_images(self.pred_dir.name, self.ref_dir.name) + # Check the exception message + self.assertIn(f'No prediction files found in {self.pred_dir.name}', str(context.exception)) + + def test_partial_matching(self): + """ + Test when some files match and some do not. + """ + self.create_temp_file(self.pred_dir.name, "sub-01_acq-01_run-01_pred.nii.gz") + self.create_temp_file(self.ref_dir.name, "sub-01_acq-01_run-01_ref.nii.gz") + self.create_temp_file(self.pred_dir.name, "sub-02_acq-02_run-02_pred.nii.gz") + + pred_files, ref_files = get_images(self.pred_dir.name, self.ref_dir.name) + self.assertEqual(len(pred_files), 1) + self.assertEqual(len(ref_files), 1) + if __name__ == '__main__': unittest.main()