From 3377f3f433b22205462ab2eec3232e574b3b715e Mon Sep 17 00:00:00 2001 From: Jan Valosek <39456460+valosekj@users.noreply.github.com> Date: Tue, 30 Apr 2024 18:58:07 +0200 Subject: [PATCH] Add python script and quick start guide for MetricsReloaded (#46) * Add python script and quick start guide for MetricsReloaded * Add unittests for compute_metrics_reloaded.py --- .github/workflows/ci.yml | 6 +- compute_metrics/compute_metrics_reloaded.py | 255 ++++++++++++++++++ .../MetricsReloaded_quick_start_guide.md | 58 ++++ tests/test_compute_metrics_reloaded.py | 211 +++++++++++++++ 4 files changed, 529 insertions(+), 1 deletion(-) create mode 100644 compute_metrics/compute_metrics_reloaded.py create mode 100644 quick_start_guides/MetricsReloaded_quick_start_guide.md create mode 100644 tests/test_compute_metrics_reloaded.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6d11199..f1f2a89 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -34,7 +34,11 @@ jobs: run: | python -m pip install --upgrade pip pip install -r dataset_conversion/requirements.txt + git clone https://github.com/ivadomed/MetricsReloaded.git + cd MetricsReloaded + python -m pip install . - name: Run tests with unittest run: | - python -m unittest tests/test_convert_bids_to_nnUNetV2.py \ No newline at end of file + python -m unittest tests/test_convert_bids_to_nnUNetV2.py + python -m unittest tests/test_compute_metrics_reloaded.py \ No newline at end of file diff --git a/compute_metrics/compute_metrics_reloaded.py b/compute_metrics/compute_metrics_reloaded.py new file mode 100644 index 0000000..b859d24 --- /dev/null +++ b/compute_metrics/compute_metrics_reloaded.py @@ -0,0 +1,255 @@ +""" +Compute MetricsReloaded metrics for segmentation tasks. +Details: https://github.com/ivadomed/MetricsReloaded + +Example usage (single reference-prediction pair): + python compute_metrics_reloaded.py + -reference sub-001_T2w_seg.nii.gz + -prediction sub-001_T2w_prediction.nii.gz + +Example usage (multiple reference-prediction pairs): + python compute_metrics_reloaded.py + -reference /path/to/reference + -prediction /path/to/prediction + +Default metrics (semantic segmentation): + - Dice similarity coefficient (DSC) + - Normalized surface distance (NSD) +(for details, see Fig. 2, Fig. 11, and Fig. 12 in https://arxiv.org/abs/2206.01653v5) + +Dice similarity coefficient (DSC): +- Fig. 65 in https://arxiv.org/pdf/2206.01653v5.pdf +- https://metricsreloaded.readthedocs.io/en/latest/reference/metrics/pairwise_measures.html#MetricsReloaded.metrics.pairwise_measures.BinaryPairwiseMeasures.dsc +Normalized surface distance (NSD): +- Fig. 86 in https://arxiv.org/pdf/2206.01653v5.pdf +- https://metricsreloaded.readthedocs.io/en/latest/reference/metrics/pairwise_measures.html#MetricsReloaded.metrics.pairwise_measures.BinaryPairwiseMeasures.normalised_surface_distance + +The script is compatible with both binary and multi-class segmentation tasks (e.g., nnunet region-based). +The metrics are computed for each unique label (class) in the reference (ground truth) image. +The output is saved to a CSV file, for example: + +reference prediction label dsc fbeta nsd vol_diff rel_vol_diff EmptyRef EmptyPred +seg.nii.gz pred.nii.gz 1.0 0.819 0.819 0.945 0.105 -10.548 False False +seg.nii.gz pred.nii.gz 2.0 0.743 0.743 0.923 0.121 -11.423 False False + +Authors: Jan Valosek +""" + + +import os +import argparse +import numpy as np +import nibabel as nib +import pandas as pd + +from MetricsReloaded.metrics.pairwise_measures import BinaryPairwiseMeasures as BPM + + +METRICS_TO_NAME = { + 'dsc': 'Dice similarity coefficient (DSC)', + 'hd': 'Hausdorff distance', + 'fbeta': 'FBeta score', + 'nsd': 'Normalized surface distance (NSD)', + 'vol_diff': 'Volume difference', + 'rel_vol_diff': 'Relative volume error (RVE)', +} + + +def get_parser(): + # parse command line arguments + parser = argparse.ArgumentParser(description='Compute MetricsReloaded metrics for segmentation tasks.') + + # Arguments for model, data, and training + parser.add_argument('-prediction', required=True, type=str, + help='Path to the folder with nifti images of test predictions or path to a single nifti image ' + 'of test prediction.') + parser.add_argument('-reference', required=True, type=str, + help='Path to the folder with nifti images of reference (ground truth) or path to a single ' + 'nifti image of reference (ground truth).') + parser.add_argument('-metrics', nargs='+', default=['dsc', 'fbeta', 'nsd', 'vol_diff', 'rel_vol_error'], + required=False, + help='List of metrics to compute. For details, ' + 'see: https://metricsreloaded.readthedocs.io/en/latest/reference/metrics/metrics.html. ' + 'Default: dsc, fbeta, nsd, vol_diff, rel_vol_error') + parser.add_argument('-output', type=str, default='metrics.csv', required=False, + help='Path to the output CSV file to save the metrics. Default: metrics.csv') + + return parser + + +def load_nifti_image(file_path): + """ + Construct absolute path to the nifti image, check if it exists, and load the image data. + :param file_path: path to the nifti image + :return: nifti image data + """ + file_path = os.path.expanduser(file_path) # resolve '~' in the path + file_path = os.path.abspath(file_path) + if not os.path.exists(file_path): + raise FileNotFoundError(f'File {file_path} does not exist.') + nifti_image = nib.load(file_path) + return nifti_image.get_fdata() + + +def get_images_in_folder(prediction, reference): + """ + Get all files (predictions and references/ground truths) in the input directories + :param prediction: path to the directory with prediction files + :param reference: path to the directory with reference (ground truth) files + :return: list of prediction files, list of reference/ground truth files + """ + # Get all files in the directories + prediction_files = [os.path.join(prediction, f) for f in os.listdir(prediction) if f.endswith('.nii.gz')] + reference_files = [os.path.join(reference, f) for f in os.listdir(reference) if f.endswith('.nii.gz')] + # Check if the number of files in the directories is the same + if len(prediction_files) != len(reference_files): + raise ValueError(f'The number of files in the directories is different. ' + f'Prediction files: {len(prediction_files)}, Reference files: {len(reference_files)}') + print(f'Found {len(prediction_files)} files in the directories.') + # Sort the files + # NOTE: Hopefully, the files are named in the same order in both directories + prediction_files.sort() + reference_files.sort() + + return prediction_files, reference_files + + +def compute_metrics_single_subject(prediction, reference, metrics): + """ + Compute MetricsReloaded metrics for a single subject + :param prediction: path to the nifti image with the prediction + :param reference: path to the nifti image with the reference (ground truth) + :param metrics: list of metrics to compute + """ + # load nifti images + prediction_data = load_nifti_image(prediction) + reference_data = load_nifti_image(reference) + + # check whether the images have the same shape and orientation + if prediction_data.shape != reference_data.shape: + raise ValueError(f'The prediction and reference (ground truth) images must have the same shape. ' + f'The prediction image has shape {prediction_data.shape} and the ground truth image has ' + f'shape {reference_data.shape}.') + + # get all unique labels (classes) + # for example, for nnunet region-based segmentation, spinal cord has label 1, and lesions have label 2 + unique_labels_reference = np.unique(reference_data) + unique_labels_reference = unique_labels_reference[unique_labels_reference != 0] # remove background + unique_labels_prediction = np.unique(prediction_data) + unique_labels_prediction = unique_labels_prediction[unique_labels_prediction != 0] # remove background + + # Get the unique labels that are present in the reference OR prediction images + unique_labels = np.unique(np.concatenate((unique_labels_reference, unique_labels_prediction))) + + # append entry into the output_list to store the metrics for the current subject + metrics_dict = {'reference': reference, 'prediction': prediction} + + # loop over all unique labels + for label in unique_labels: + # create binary masks for the current label + print(f'Processing label {label}') + prediction_data_label = np.array(prediction_data == label, dtype=float) + reference_data_label = np.array(reference_data == label, dtype=float) + + bpm = BPM(prediction_data_label, reference_data_label, measures=metrics) + dict_seg = bpm.to_dict_meas() + # Store info whether the reference or prediction is empty + dict_seg['EmptyRef'] = bpm.flag_empty_ref + dict_seg['EmptyPred'] = bpm.flag_empty_pred + # add the metrics to the output dictionary + metrics_dict[label] = dict_seg + + if label == max(unique_labels): + break # break to loop to avoid processing the background label ("else" block) + # Special case when both the reference and prediction images are empty + else: + label = 1 + print(f'Processing label {label} -- both the reference and prediction are empty') + bpm = BPM(prediction_data, reference_data, measures=metrics) + dict_seg = bpm.to_dict_meas() + + # Store info whether the reference or prediction is empty + dict_seg['EmptyRef'] = bpm.flag_empty_ref + dict_seg['EmptyPred'] = bpm.flag_empty_pred + # add the metrics to the output dictionary + metrics_dict[label] = dict_seg + + return metrics_dict + + +def build_output_dataframe(output_list): + """ + Convert JSON data to pandas DataFrame + :param output_list: list of dictionaries with metrics + :return: pandas DataFrame + """ + rows = [] + for item in output_list: + # Extract all keys except 'reference' and 'prediction' to get labels (e.g. 1.0, 2.0, etc.) dynamically + labels = [key for key in item.keys() if key not in ['reference', 'prediction']] + for label in labels: + metrics = item[label] # Get the dictionary of metrics + # Dynamically add all metrics for the label + row = { + "reference": item["reference"], + "prediction": item["prediction"], + "label": label, + } + # Update row with all metrics dynamically + row.update(metrics) + rows.append(row) + + df = pd.DataFrame(rows) + + return df + + +def main(): + + # parse command line arguments + parser = get_parser() + args = parser.parse_args() + + # Initialize a list to store the output dictionaries (representing a single reference-prediction pair per subject) + output_list = list() + + # Args.prediction and args.reference are paths to folders with multiple nii.gz files (i.e., multiple subjects) + if os.path.isdir(args.prediction) and os.path.isdir(args.reference): + # Get all files in the directories + prediction_files, reference_files = get_images_in_folder(args.prediction, args.reference) + # Loop over the subjects + for i in range(len(prediction_files)): + # Compute metrics for each subject + metrics_dict = compute_metrics_single_subject(prediction_files[i], reference_files[i], args.metrics) + # Append the output dictionary (representing a single reference-prediction pair per subject) to the + # output_list + output_list.append(metrics_dict) + # Args.prediction and args.reference are paths nii.gz files from a single subject + else: + metrics_dict = compute_metrics_single_subject(args.prediction, args.reference, args.metrics) + # Append the output dictionary (representing a single reference-prediction pair per subject) to the output_list + output_list.append(metrics_dict) + + # Convert JSON data to pandas DataFrame + df = build_output_dataframe(output_list) + + # Rename columns + df.rename(columns={metric: METRICS_TO_NAME[metric] for metric in METRICS_TO_NAME}, inplace=True) + + # save as CSV + fname_output_csv = os.path.abspath(args.output) + df.to_csv(fname_output_csv, index=False) + print(f'Saved metrics to {fname_output_csv}.') + + # Compute mean and standard deviation of metrics across all subjects + df_mean = (df.drop(columns=['reference', 'prediction', 'EmptyRef', 'EmptyPred']).groupby('label'). + agg(['mean', 'std']).reset_index()) + + # save as CSV + fname_output_csv_mean = os.path.abspath(args.output.replace('.csv', '_mean.csv')) + df_mean.to_csv(fname_output_csv_mean, index=False) + print(f'Saved mean and standard deviation of metrics across all subjects to {fname_output_csv_mean}.') + + +if __name__ == '__main__': + main() diff --git a/quick_start_guides/MetricsReloaded_quick_start_guide.md b/quick_start_guides/MetricsReloaded_quick_start_guide.md new file mode 100644 index 0000000..63127f7 --- /dev/null +++ b/quick_start_guides/MetricsReloaded_quick_start_guide.md @@ -0,0 +1,58 @@ +# MetricsReloaded quick-start guide + +Useful links: +- [MetricsReloaded documentation](https://metricsreloaded.readthedocs.io/en/latest/) +- [MetricsReloaded publication](https://www.nature.com/articles/s41592-023-02151-z) +- [MetricsReloaded preprint](https://arxiv.org/pdf/2206.01653v5.pdf) - preprint contains more figures than the publication + +## Installation + +The installation instructions are available [here](https://github.com/ivadomed/MetricsReloaded?tab=readme-ov-file#installation). + +> **Note** +> Note that we use an ivadomed fork. + + +> **Note** +> Always install MetricsReloaded inside a virtual environment. + +``` +# Create and activate a new conda environment +conda create -n metrics_reloaded python=3.10 pip +conda activate metrics_reloaded + +# Clone the repository +cd ~/code +git clone https://github.com/ivadomed/MetricsReloaded +cd MetricsReloaded + +# Install the package +python -m pip install . +# You can alternatively install the package in editable mode: +python -m pip install -e . +``` + +## Usage + +You can use the [compute_metrics_reloaded.py](../compute_metrics/compute_metrics_reloaded.py) script to compute metrics using the MetricsReloaded package. + +```commandline +python compute_metrics_reloaded.py -reference sub-001_T2w_seg.nii.gz -prediction sub-001_T2w_prediction.nii.gz +``` + +Default metrics (semantic segmentation): + - Dice similarity coefficient (DSC) + - Normalized surface distance (NSD) +(for details, see Fig. 2, Fig. 11, and Fig. 12 in https://arxiv.org/abs/2206.01653v5) + +The script is compatible with both binary and multi-class segmentation tasks (e.g., nnunet region-based). + +The metrics are computed for each unique label (class) in the reference (ground truth) image. + +The output is saved to a CSV file, for example: + +```csv +reference prediction label dsc fbeta nsd vol_diff rel_vol_diff EmptyRef EmptyPred +seg.nii.gz pred.nii.gz 1.0 0.819 0.819 0.945 0.105 -10.548 False False +seg.nii.gz pred.nii.gz 2.0 0.743 0.743 0.923 0.121 -11.423 False False +``` \ No newline at end of file diff --git a/tests/test_compute_metrics_reloaded.py b/tests/test_compute_metrics_reloaded.py new file mode 100644 index 0000000..4cc96fe --- /dev/null +++ b/tests/test_compute_metrics_reloaded.py @@ -0,0 +1,211 @@ +####################################################################### +# +# Tests for the `compute_metrics/compute_metrics_reloaded.py` script +# +# RUN BY: +# python -m unittest tests/test_compute_metrics_reloaded.py +####################################################################### + +import unittest +import os +import numpy as np +import nibabel as nib +from compute_metrics.compute_metrics_reloaded import compute_metrics_single_subject +import tempfile + +METRICS = ['dsc', 'fbeta', 'nsd', 'vol_diff', 'rel_vol_error'] + + +class TestComputeMetricsReloaded(unittest.TestCase): + def setUp(self): + # Use tempfile.NamedTemporaryFile to create temporary nifti files + self.ref_file = tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) + self.pred_file = tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) + self.metrics = METRICS + + def create_dummy_nii(self, file_obj, data): + img = nib.Nifti1Image(data, np.eye(4)) + nib.save(img, file_obj.name) + file_obj.seek(0) # Move back to the beginning of the file + + def tearDown(self): + # Close and remove temporary files + self.ref_file.close() + os.unlink(self.ref_file.name) + self.pred_file.close() + os.unlink(self.pred_file.name) + + def assert_metrics(self, metrics_dict, expected_metrics): + for metric in self.metrics: + # Loop over labels/classes (e.g., 1, 2, ...) + for label in expected_metrics.keys(): + # if value is nan, use np.isnan to check + if np.isnan(expected_metrics[label][metric]): + self.assertTrue(np.isnan(metrics_dict[label][metric])) + # if value is inf, use np.isinf to check + elif np.isinf(expected_metrics[label][metric]): + self.assertTrue(np.isinf(metrics_dict[label][metric])) + else: + self.assertAlmostEqual(metrics_dict[label][metric], expected_metrics[label][metric]) + + def test_empty_ref_and_pred(self): + """ + Empty reference and empty prediction + """ + + expected_metrics = {1.0: {'EmptyPred': True, + 'EmptyRef': True, + 'dsc': 1, + 'fbeta': 1, + 'nsd': np.nan, + 'rel_vol_error': 0, + 'vol_diff': np.nan}} + + # Create empty reference + self.create_dummy_nii(self.ref_file, np.zeros((10, 10, 10))) + # Create empty prediction + self.create_dummy_nii(self.pred_file, np.zeros((10, 10, 10))) + # Compute metrics + metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics) + # Assert metrics + self.assert_metrics(metrics_dict, expected_metrics) + + def test_empty_ref(self): + """ + Empty reference and non-empty prediction + """ + + expected_metrics = {1.0: {'EmptyPred': False, + 'EmptyRef': True, + 'dsc': 0.0, + 'fbeta': 0, + 'nsd': 0.0, + 'rel_vol_error': 100, + 'vol_diff': np.inf}} + + # Create empty reference + self.create_dummy_nii(self.ref_file, np.zeros((10, 10, 10))) + # Create non-empty prediction + pred = np.zeros((10, 10, 10)) + pred[5:7, 2:5] = 1 + self.create_dummy_nii(self.pred_file, pred) + # Compute metrics + metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics) + # Assert metrics + self.assert_metrics(metrics_dict, expected_metrics) + + def test_empty_pred(self): + """ + Non-empty reference and empty prediction + """ + + expected_metrics = {1.0: {'EmptyPred': True, + 'EmptyRef': False, + 'dsc': 0.0, + 'fbeta': 0, + 'nsd': 0.0, + 'rel_vol_error': -100.0, + 'vol_diff': 1.0}} + + # Create non-empty reference + ref = np.zeros((10, 10, 10)) + ref[5:7, 2:5] = 1 + self.create_dummy_nii(self.ref_file, ref) + # Create empty prediction + self.create_dummy_nii(self.pred_file, np.zeros((10, 10, 10))) + # Compute metrics + metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics) + # Assert metrics + self.assert_metrics(metrics_dict, expected_metrics) + + def test_non_empty_ref_and_pred(self): + """ + Non-empty reference and non-empty prediction with partial overlap + """ + + expected_metrics = {1.0: {'EmptyPred': False, + 'EmptyRef': False, + 'dsc': 0.26666666666666666, + 'fbeta': 0.26666667461395266, + 'nsd': 0.5373134328358209, + 'rel_vol_error': 300.0, + 'vol_diff': 3.0}} + + # Create non-empty reference + ref = np.zeros((10, 10, 10)) + ref[4:5, 3:6] = 1 + self.create_dummy_nii(self.ref_file, ref) + # Create non-empty prediction + pred = np.zeros((10, 10, 10)) + pred[4:8, 2:5] = 1 + self.create_dummy_nii(self.pred_file, pred) + # Compute metrics + metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics) + # Assert metrics + self.assert_metrics(metrics_dict, expected_metrics) + + def test_non_empty_ref_and_pred_multi_class(self): + """ + Non-empty reference and non-empty prediction with partial overlap + Multi-class (i.e., voxels with values 1 and 2, e.g., region-based nnUNet training) + """ + + expected_metrics = {1.0: {'dsc': 0.25, + 'fbeta': 0.2500000055879354, + 'nsd': 0.5, + 'vol_diff': 2.0, + 'rel_vol_error': 200.0, + 'EmptyRef': False, + 'EmptyPred': False}, + 2.0: {'dsc': 0.26666666666666666, + 'fbeta': 0.26666667461395266, + 'nsd': 0.5373134328358209, + 'vol_diff': 3.0, + 'rel_vol_error': 300.0, + 'EmptyRef': False, + 'EmptyPred': False}} + + # Create non-empty reference + ref = np.zeros((10, 10, 10)) + ref[4:5, 3:10] = 1 + ref[4:5, 3:6] = 2 # e.g., lesion within spinal cord + self.create_dummy_nii(self.ref_file, ref) + # Create non-empty prediction + pred = np.zeros((10, 10, 10)) + pred[4:8, 2:8] = 1 + pred[4:8, 2:5] = 2 # e.g., lesion within spinal cord + self.create_dummy_nii(self.pred_file, pred) + # Compute metrics + metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics) + # Assert metrics + self.assert_metrics(metrics_dict, expected_metrics) + + def test_non_empty_ref_and_pred_with_full_overlap(self): + """ + Non-empty reference and non-empty prediction with full overlap + """ + + expected_metrics = {1.0: {'EmptyPred': False, + 'EmptyRef': False, + 'dsc': 1.0, + 'fbeta': 1.0, + 'nsd': 1.0, + 'rel_vol_error': 0.0, + 'vol_diff': 0.0}} + + # Create non-empty reference + ref = np.zeros((10, 10, 10)) + ref[4:8, 2:5] = 1 + self.create_dummy_nii(self.ref_file, ref) + # Create non-empty prediction + pred = np.zeros((10, 10, 10)) + pred[4:8, 2:5] = 1 + self.create_dummy_nii(self.pred_file, pred) + # Compute metrics + metrics_dict = compute_metrics_single_subject(self.pred_file.name, self.ref_file.name, self.metrics) + # Assert metrics + self.assert_metrics(metrics_dict, expected_metrics) + + +if __name__ == '__main__': + unittest.main()