Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix group-wise evaluation #12

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
143 changes: 112 additions & 31 deletions compute_metrics_reloaded.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
The script is compatible with both binary and multi-class segmentation tasks (e.g., nnunet region-based).
The metrics are computed for each unique label (class) in the reference (ground truth) image.

Authors: Jan Valosek
Authors: Jan Valosek, Naga Karthik
"""


Expand All @@ -41,6 +41,8 @@
import numpy as np
import nibabel as nib
import pandas as pd
import re
from tqdm import tqdm

from MetricsReloaded.metrics.pairwise_measures import BinaryPairwiseMeasures as BPM

Expand Down Expand Up @@ -81,6 +83,8 @@ def get_parser():
'see: https://metricsreloaded.readthedocs.io/en/latest/reference/metrics/metrics.html.')
parser.add_argument('-output', type=str, default='metrics.csv', required=False,
help='Path to the output CSV file to save the metrics. Default: metrics.csv')
parser.add_argument('-mask-type', type=str, default='chunks', required=False,
help='Type of the labels in the images. Options: [chunks, stitched]')

return parser

Expand Down Expand Up @@ -122,25 +126,13 @@ def get_images_in_folder(prediction, reference):
return prediction_files, reference_files


def compute_metrics_single_subject(prediction, reference, metrics):
def compute_metrics_single_subject(prediction_data, reference_data, metrics, metrics_dict):
"""
Compute MetricsReloaded metrics for a single subject
:param prediction: path to the nifti image with the prediction
:param reference: path to the nifti image with the reference (ground truth)
:param prediction: numpy array of the prediction mask
:param reference: numpy array of the reference mask (ground truth)
:param metrics: list of metrics to compute
"""
# load nifti images
print(f'Processing...')
print(f'\tPrediction: {os.path.basename(prediction)}')
print(f'\tReference: {os.path.basename(reference)}')
prediction_data = load_nifti_image(prediction)
reference_data = load_nifti_image(reference)

# check whether the images have the same shape and orientation
if prediction_data.shape != reference_data.shape:
raise ValueError(f'The prediction and reference (ground truth) images must have the same shape. '
f'The prediction image has shape {prediction_data.shape} and the ground truth image has '
f'shape {reference_data.shape}.')

# get all unique labels (classes)
# for example, for nnunet region-based segmentation, spinal cord has label 1, and lesions have label 2
Expand All @@ -152,14 +144,10 @@ def compute_metrics_single_subject(prediction, reference, metrics):
# Get the unique labels that are present in the reference OR prediction images
unique_labels = np.unique(np.concatenate((unique_labels_reference, unique_labels_prediction)))

# append entry into the output_list to store the metrics for the current subject
metrics_dict = {'reference': reference, 'prediction': prediction}

# loop over all unique labels, e.g., voxels with values 1, 2, ...
# by doing this, we can compute metrics for each label separately, e.g., separately for spinal cord and lesions
for label in unique_labels:
# create binary masks for the current label
print(f'\tLabel {label}')
prediction_data_label = np.array(prediction_data == label, dtype=float)
reference_data_label = np.array(reference_data == label, dtype=float)

Expand All @@ -176,7 +164,6 @@ def compute_metrics_single_subject(prediction, reference, metrics):
# Special case when both the reference and prediction images are empty
else:
label = 1
print(f'\tLabel {label} -- both the reference and prediction are empty')
bpm = BPM(prediction_data, reference_data, measures=metrics)
dict_seg = bpm.to_dict_meas()

Expand Down Expand Up @@ -216,8 +203,22 @@ def build_output_dataframe(output_list):
return df


def main():
def find_subject_session_chunk_in_path(path):
    """
    Extract the combined subject-session identifier and the chunk identifier from the given path.

    The path must contain `_sub-m<6 digits>_ses-<8 digits>` followed, later in the string, by
    `_chunk-<1 digit>_`.

    :param path: Input path containing the subject, session and chunk identifiers.
    :return: Tuple `(subject_session, chunk)`, e.g. `('sub-m123456_ses-20200101', 'chunk-1')`,
             or `(None, None)` if the pattern is not found.
    """
    pattern = r'.*_(sub-m\d{6}_ses-\d{8}).*_(chunk-\d{1})_.*'
    match = re.search(pattern, path)
    if match is None:
        # Return the same number of elements as in the success case, so that callers can always
        # unpack the result into two values regardless of whether the pattern matched
        return None, None
    return match.group(1), match.group(2)


def main():
# parse command line arguments
parser = get_parser()
args = parser.parse_args()
Expand All @@ -232,26 +233,106 @@ def main():
if os.path.isdir(args.prediction) and os.path.isdir(args.reference):
# Get all files in the directories
prediction_files, reference_files = get_images_in_folder(args.prediction, args.reference)
# Loop over the subjects
for i in range(len(prediction_files)):
# Compute metrics for each subject
metrics_dict = compute_metrics_single_subject(prediction_files[i], reference_files[i], args.metrics)
# Append the output dictionary (representing a single reference-prediction pair per subject) to the
# output_list
output_list.append(metrics_dict)

if args.mask_type == 'chunks':

# get the subject, session, and chunk identifiers from the path
subjects_sessions = [find_subject_session_chunk_in_path(f)[0] for f in prediction_files if find_subject_session_chunk_in_path(f)]
subjects_sessions = list(set(subjects_sessions))

for sub_ses in tqdm(subjects_sessions, desc='Computing metrics for each subject'):
preds_per_sub_ses = [f for f in prediction_files if sub_ses in f]
refs_per_sub_ses = [f for f in reference_files if sub_ses in f]

preds_stack, refs_stack = [], []
for pred, ref in zip(preds_per_sub_ses, refs_per_sub_ses):
# load nifti images
prediction_data = load_nifti_image(pred)
reference_data = load_nifti_image(ref)

# check whether the images have the same shape and orientation
if prediction_data.shape != reference_data.shape:
raise ValueError(f'The prediction and reference (ground truth) images must have the same shape. '
f'The prediction image has shape {prediction_data.shape} and the ground truth image has '
f'shape {reference_data.shape}.')

preds_stack.append(prediction_data)
refs_stack.append(reference_data)

# min_shape = np.min([pred.shape for pred in preds_stack], axis=0)
max_shape = np.max([pred.shape for pred in preds_stack], axis=0)
max_shape_ref = np.max([ref.shape for ref in refs_stack], axis=0)

assert max_shape[0] == max_shape_ref[0], "The images must have the same shape at dim[0]"
assert max_shape[1] == max_shape_ref[1], "The images must have the same shape at dim[1]"
assert max_shape[2] == max_shape_ref[2], "The images must have the same shape at dim[2]"

# pad the images to the same shape
preds_stack = [np.pad(pred, ((0, max_shape[0] - pred.shape[0]), (0, max_shape[1] - pred.shape[1]), (0, max_shape[2] - pred.shape[2]))) for pred in preds_stack]
refs_stack = [np.pad(ref, ((0, max_shape[0] - ref.shape[0]), (0, max_shape[1] - ref.shape[1]), (0, max_shape[2] - ref.shape[2]))) for ref in refs_stack]

# stack the images
preds_stacked = np.stack(preds_stack, axis=-1)
refs_stacked = np.stack(refs_stack, axis=-1)

# create a new file name for reference and prediction
pred_fname = os.path.join(os.path.dirname(preds_per_sub_ses[0]), f'{sub_ses}_preds_stack.nii.gz')
ref_fname = os.path.join(os.path.dirname(refs_per_sub_ses[0]), f'{sub_ses}_refs_stack.nii.gz')

metrics_dict = {'reference': ref_fname, 'prediction': pred_fname}
# Compute metrics for each subject
metrics_dict = compute_metrics_single_subject(preds_stacked, refs_stacked, args.metrics, metrics_dict)

# append the dictionary to the output list
output_list.append(metrics_dict)

elif args.mask_type == 'stitched':
# Loop over the subjects
for i in tqdm(range(len(reference_files)), desc='Computing metrics for each subject'):

# Load nifti images
prediction_data = load_nifti_image(prediction_files[i])
reference_data = load_nifti_image(reference_files[i])

# append entry into the output_list to store the metrics for the current subject
metrics_dict = {'reference': reference_files[i], 'prediction': prediction_files[i]}
# Compute metrics for each subject
metrics_dict = compute_metrics_single_subject(prediction_data, reference_data, args.metrics, metrics_dict)

# Append the output dictionary (representing a single reference-prediction pair per subject) to the
# output_list
output_list.append(metrics_dict)

# Args.prediction and args.reference are paths to nii.gz files from a SINGLE subject
else:
metrics_dict = compute_metrics_single_subject(args.prediction, args.reference, args.metrics)
# Load nifti images
prediction_data = load_nifti_image(args.prediction)
reference_data = load_nifti_image(args.reference)

metrics_dict = {'reference': args.reference, 'prediction': args.prediction}
metrics_dict = compute_metrics_single_subject(prediction_data, reference_data, args.metrics, metrics_dict)

# Append the output dictionary (representing a single reference-prediction pair per subject) to the output_list
output_list.append(metrics_dict)

# Convert JSON data to pandas DataFrame
df = build_output_dataframe(output_list)

# create a separate dataframe for columns where EmptyRef and EmptyPred is True
df_empty_masks = df[(df['EmptyRef'] == True) & (df['EmptyPred'] == True)]

# keep only the rows where either pred or ref is non-empty or both are non-empty
df = df[(df['EmptyRef'] == False) | (df['EmptyPred'] == False)]

# Compute mean and standard deviation of metrics across all subjects
df_mean = (df.drop(columns=['reference', 'prediction', 'EmptyRef', 'EmptyPred']).groupby('label').
agg(['mean', 'std']).reset_index())

# Convert multi-index to flat index
df_mean.columns = ['_'.join(col).strip() for col in df_mean.columns.values]
# Rename column `label_` back to `label`
df_mean.rename(columns={'label_': 'label'}, inplace=True)

# Rename columns
df.rename(columns={metric: METRICS_TO_NAME[metric] for metric in METRICS_TO_NAME}, inplace=True)
df_mean.rename(columns={metric: METRICS_TO_NAME[metric] for metric in METRICS_TO_NAME}, inplace=True)
Expand All @@ -272,4 +353,4 @@ def main():


if __name__ == '__main__':
main()
main()
Loading