diff --git a/.gitignore b/.gitignore index 600d2d3..6d0ee45 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ -.vscode \ No newline at end of file +.vscode +.DS_Store \ No newline at end of file diff --git a/scripts/register_AMU_to_PAM50.py b/scripts/register_AMU_to_PAM50.py new file mode 100644 index 0000000..0d6172a --- /dev/null +++ b/scripts/register_AMU_to_PAM50.py @@ -0,0 +1,331 @@ + +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Reproducible registration script (no CLI args). +Pipeline: + 1) Step-0 (label-based, Tx_Ty_Tz) AMU -> PAM50 to get initial rigid/translation alignment. + 2) Apply step-0 warp to AMU images to bring them into PAM50 space. + 3) Extend AMU volumes in PAM50 space by copying top/bottom slices to match PAM50 cord extent. + 4) Global slicewise registration (SCT) on extended volumes: + sct_register_multimodal step=1, type=seg, algo=slicereg, poly=2 (standard coarse alignment) + 5) Apply global warp (warp1 โˆ˜ file_warp0) to produce baseline registered AMU_T2* and FILE_AMU_GM. + 6) Detect extended Z-slices and run **per-slice isct_antsRegistration** ONLY on those slices + "de proche en proche", initializing each slice with previous transform. + Overwrite the corresponding slices in the baseline outputs. + 7) Symmetrize and threshold (T2*). + +Assumptions: +- SCT binaries in PATH: sct_label_utils, sct_register_multimodal, sct_apply_transfo, sct_maths +- symmetrize_cord_segmentation.py is located in the SAME directory as this script +- Edit CONFIG below, then just run the script. +""" +import os +import subprocess +from pathlib import Path +import nibabel as nib +import numpy as np + +# ========================= +# CONFIG (edit once) +# ========================= +SCT_DIR = os.environ.get("SCT_DIR", "/opt/sct") # Adjust if not using env var +FILE_AMU_T2S = Path(os.path.expanduser("~/Desktop/MNI-POLY-AMU/AMU15_T2star_sym.nii.gz")) +FILE_AMU_GM = Path(os.path.expanduser("~/Desktop/MNI-POLY-AMU/AMU15_GW_sym.nii.gz")) # GM+WM segmentation (moving seg) + +FILE_PAM50_T2 = Path(f"{SCT_DIR}/data/PAM50/template/PAM50_t2.nii.gz") +FILE_PAM50_SEG = Path(f"{SCT_DIR}/data/PAM50/template/PAM50_cord.nii.gz") + +# Label coordinates (x,y,z,val) for step-0 label-based alignment +LABEL_AMU = (75, 75, 965, 1) +LABEL_PAM = (70, 70, 959, 1) + +# I/O +OUTDIR = Path("./results") +QCDIR = (OUTDIR / "qc").resolve() + +HERE = Path(__file__).resolve().parent +SYM_SCRIPT = HERE / "symmetrize_cord_segmentation.py" + +def run(cmd, check=True): + print("+", " ".join(map(str, cmd)), flush=True) + res = subprocess.run(list(map(str, cmd)), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) + print(res.stdout) + if check and res.returncode != 0: + raise RuntimeError(f"Command failed: {' '.join(map(str, cmd))}") + return res + +def nz_mask_per_slice(vol_path: Path): + img = nib.load(str(vol_path)) + data = img.get_fdata() + if data.ndim != 3: + raise ValueError("Expected a 3D volume") + Z = data.shape[2] + mask = np.zeros(Z, dtype=bool) + for z in range(Z): + mask[z] = np.any(data[:, :, z] != 0) + return mask + +def copy_edge_slices_to_match(moving_path: Path, ref_path: Path, out_path: Path): + """ + Make the moving image's Z dimension match the reference by either: + - Cropping (if moving is longer in Z) + - Padding and copying edge slices (if moving is shorter in Z) + Alignment uses NZ-span centers to preserve the cord region. + """ + img_m = nib.load(str(moving_path)) + img_r = nib.load(str(ref_path)) + data_m = img_m.get_fdata() + data_r = img_r.get_fdata() + + if data_m.ndim != 3 or data_r.ndim != 3: + raise ValueError("Input images must be 3D.") + + def nz_bounds(arr): + proj = np.sum(np.sum(arr != 0, axis=0), axis=0) + nz = np.where(proj > 0)[0] + if nz.size == 0: + return 0, arr.shape[2] - 1 + return int(nz.min()), int(nz.max()) + + zmin_m, zmax_m = nz_bounds(data_m) + zmin_r, zmax_r = nz_bounds(data_r) + z_m = data_m.shape[2] + z_r = data_r.shape[2] + + c_m = 0.5 * (zmin_m + zmax_m) + + if z_m > z_r: + # Crop moving to match reference length + start = int(round(c_m - z_r / 2.0)) + start = max(0, min(start, z_m - z_r)) + data_ext = data_m[:, :, start:start + z_r].copy() + elif z_m < z_r: + # Pad moving to match reference length, then copy edges + data_ext = np.zeros((data_m.shape[0], data_m.shape[1], z_r), dtype=data_m.dtype) + insert_start = int(round((z_r / 2.0) - (c_m))) + insert_start = max(0, min(insert_start, z_r - z_m)) + data_ext[:, :, insert_start:insert_start + z_m] = data_m + + proj = np.any(np.any(data_ext != 0, axis=0), axis=0) + first_nz = np.argmax(proj) + if first_nz > 0: + data_ext[:, :, :first_nz] = np.repeat(data_ext[:, :, first_nz:first_nz + 1], first_nz, axis=2) + proj = np.any(np.any(data_ext != 0, axis=0), axis=0) + last_nz = len(proj) - 1 - np.argmax(proj[::-1]) + if last_nz < data_ext.shape[2] - 1: + fill_len = data_ext.shape[2] - 1 - last_nz + data_ext[:, :, last_nz + 1:] = np.repeat(data_ext[:, :, last_nz:last_nz + 1], fill_len, axis=2) + else: + data_ext = data_m.copy() + proj = np.any(np.any(data_ext != 0, axis=0), axis=0) + first_nz = np.argmax(proj) + last_nz = len(proj) - 1 - np.argmax(proj[::-1]) + if first_nz > 0: + data_ext[:, :, :first_nz] = np.repeat(data_ext[:, :, first_nz:first_nz + 1], first_nz, axis=2) + if last_nz < data_ext.shape[2] - 1: + fill_len = data_ext.shape[2] - 1 - last_nz + data_ext[:, :, last_nz + 1:] = np.repeat(data_ext[:, :, last_nz:last_nz + 1], fill_len, axis=2) + + nib.Nifti1Image(data_ext, img_m.affine, img_m.header).to_filename(str(out_path)) + return out_path + + +# ========================================================================== +# SCRIPT STARTS HERE +# ========================================================================== + +OUTDIR.mkdir(parents=True, exist_ok=True) +QCDIR.mkdir(parents=True, exist_ok=True) +workdir = OUTDIR / "work" +workdir.mkdir(parents=True, exist_ok=True) +cwd = os.getcwd() +os.chdir(workdir) + +# Create pointwise labels to bring AMU template in PAM50 space +file_label_amu = Path(f"label_AMU.nii.gz") +file_label_pam = Path(f"label_PAM50.nii.gz") +run(["sct_label_utils", "-i", FILE_AMU_T2S, "-create", ",".join(map(str, LABEL_AMU)), "-o", file_label_amu]) +run(["sct_label_utils", "-i", FILE_PAM50_T2, "-create", ",".join(map(str, LABEL_PAM)), "-o", file_label_pam]) + +# Label-based registration (Tx_Ty_Tz), followed by non-linear registration using cord segmentation +run([ + "sct_register_multimodal", + "-i", FILE_AMU_T2S, + "-iseg", FILE_AMU_GM, + "-ilabel", file_label_amu, + "-d", FILE_PAM50_T2, + "-dseg", FILE_PAM50_SEG, + "-dlabel", file_label_pam, + "-param", "step=0,type=label,dof=Tx_Ty_Tz:step=1,type=seg,algo=slicereg,poly=2", + "-qc", QCDIR, +]) +file_amu_t2s_step0 = Path(FILE_AMU_T2S.name.removesuffix('.nii.gz') + "_step0.nii.gz") +# Rename for clarity +os.rename(Path(FILE_AMU_T2S.name.removesuffix('.nii.gz') + "_reg.nii.gz"), file_amu_t2s_step0) + +# Locate warp +srcbase0 = Path(FILE_AMU_T2S).name.replace(".nii.gz","") +dstbase = Path(FILE_PAM50_T2).name.replace(".nii.gz","") +file_warp0 = Path(f"warp_{srcbase0}2{dstbase}.nii.gz") +if not file_warp0.exists(): + raise FileNotFoundError("Could not find {file_warp0}") + +# Apply warp to AMU cord segmentation +file_amu_g_step0 = Path(FILE_AMU_GM.name.removesuffix('.nii.gz') + "_step0.nii.gz") +run(["sct_apply_transfo", "-i", FILE_AMU_GM, "-d", FILE_PAM50_T2, "-w", file_warp0, "-x", "linear", "-o", file_amu_g_step0]) + +# Extend top/bottom mask to cover the full PAM50 space +print("==> Extending PAM50-space AMU images along Z to match PAM50 cord segmentation extent.") +file_amu_t2s_step0_ext = Path(file_amu_t2s_step0.name.removesuffix('.nii.gz') + "_ext.nii.gz") +file_amu_g_step0_ext = Path(file_amu_g_step0.name.removesuffix('.nii.gz') + "_ext.nii.gz") +copy_edge_slices_to_match(file_amu_t2s_step0, FILE_PAM50_SEG, file_amu_t2s_step0_ext) +copy_edge_slices_to_match(file_amu_g_step0, FILE_PAM50_SEG, file_amu_g_step0_ext) + +# TODO: the code below can be simplified a lot, if we don't care about the "truly extended" slices. +# Also, we might replace the bottom extension with actual concatenation of another set of images. +# ---------- Per-slice refinement ONLY on truly extended slices ---------- +def nz_mask_per_slice_data(arr): + Z = arr.shape[2] + mask = np.zeros(Z, dtype=bool) + for z in range(Z): + mask[z] = np.any(arr[:, :, z] != 0) + return mask + +# Extended detection via step0 vs step0_ext (GM) +mask_step0 = nz_mask_per_slice(file_amu_g_step0) +mask_ext = nz_mask_per_slice(file_amu_g_step0_ext) +Z = mask_ext.shape[0] +extended = np.logical_and(~mask_step0, mask_ext) +nz_indices = np.where(mask_step0)[0] +if nz_indices.size == 0: + z_core_min, z_core_max = 0, -1 +else: + z_core_min, z_core_max = int(nz_indices.min()), int(nz_indices.max()) +bottom_ext = np.where(extended[:z_core_min])[0] if z_core_min >= 1 else np.array([], dtype=int) +top_ext = (z_core_max + 1) + np.where(extended[z_core_max + 1:])[0] if z_core_max < Z - 1 else np.array([], dtype=int) + +print(f"Core NZ span: [{z_core_min}, {z_core_max}]") +print(f"Bottom extended slices: {bottom_ext.tolist()}") +print(f"Top extended slices: {top_ext.tolist()}") + +# Load baseline registered 3D outputs and the extended volumes for picking slices +ref_fix_seg = nib.load(str(FILE_PAM50_SEG)) +fix_seg_3d = ref_fix_seg.get_fdata() +gm_ext_3d = nib.load(str(file_amu_g_step0_ext)).get_fdata() +t2s_ext_3d = nib.load(str(file_amu_t2s_step0_ext)).get_fdata() + +# Helper to run ants on a single slice and overwrite in the baseline outputs +def ants_slice_refine(z, prev_mat=None): + # Build 2D (single-slice) fixed/moving files + fix = workdir / f"fix_seg_z{z:04d}.nii.gz" + mov_g = workdir / f"mov_g_ext_z{z:04d}.nii.gz" + mov_t = workdir / f"mov_t2s_ext_z{z:04d}.nii.gz" + + fix_slice = nib.load(str(FILE_PAM50_SEG)).get_fdata() + gm_ext_3d = nib.load(str(file_amu_g_step0_ext)).get_fdata() + t2_ext_3d = nib.load(str(file_amu_t2s_step0_ext)).get_fdata() + + # write as single-slice 2D + # aff2d = affine_3d_to_2d(nib.load(str(FILE_PAM50_SEG)).affine) + # nib.Nifti1Image(fix_slice[:, :, z], aff2d, header_3d_to_2d(nib.load(str(FILE_PAM50_SEG)).header, np.shape(fix_slice[:, :, z]), aff2d)).to_filename(str(fix)) + nib.Nifti1Image(fix_slice[:, :, z], nib.load(str(FILE_PAM50_SEG)).affine, nib.load(str(FILE_PAM50_SEG)).header).to_filename(str(fix)) + nib.Nifti1Image(gm_ext_3d[:, :, z], nib.load(str(file_amu_g_step0_ext)).affine, nib.load(str(file_amu_g_step0_ext)).header).to_filename(str(mov_g)) + nib.Nifti1Image(t2_ext_3d[:, :, z], nib.load(str(file_amu_t2s_step0_ext)).affine, nib.load(str(file_amu_t2s_step0_ext)).header).to_filename(str(mov_t)) + + # Run 2D rigid ANTs (use previous sliceโ€™s affine as init if provided) + out_prefix = workdir / f"ants_z{z:04d}_" + cmd = ["isct_antsRegistration", + "-d", "2", + # Stage 1: rigid + "-t", "Rigid[0.1]", + "-m", f"MeanSquares[{fix},{mov_g},1,4]", + "-c", "50x20", + "-s", "1x0", + "-f", "2x1", + # Stage 2: bsplineSyN non-linear refinement + "-t", "BSplineSyN[0.1,26,0,3]", + "-m", f"MeanSquares[{fix},{mov_g},1,4]", + "-c", "50x20", + "-s", "1x0", + "-f", "2x1"] + if prev_mat is not None: + cmd += ["-r", str(prev_mat)] + cmd += ["-o", str(out_prefix)] + run(cmd) + + mat = Path(str(out_prefix) + "0GenericAffine.mat") + if not mat.exists(): + raise FileNotFoundError(f"ants transform not found for slice z={z}: {mat}") + warp = Path(str(out_prefix) + "1Warp.nii.gz") + if not warp.exists(): + raise FileNotFoundError(f"ants transform not found for slice z={z}: {warp}") + + # Apply to GM (NN) and T2* (linear) -- ANTs will write 2-D images + out_g = workdir / f"amu_g_refined_z{z:04d}.nii.gz" + out_t = workdir / f"amu_t2s_refined_z{z:04d}.nii.gz" + run(["isct_antsApplyTransforms", "-d", "2", + "-i", mov_g, "-r", fix, "-t", mat, warp, "-o", out_g, "-n", "Linear"]) + run(["isct_antsApplyTransforms", "-d", "2", + "-i", mov_t, "-r", fix, "-t", mat, warp, "-o", out_t, "-n", "Linear"]) + # Copy header info (spacing) from input to output + # TODO replace code below with nibabel for faster execution + run(["sct_image", "-i", str(fix), "-copy-header", str(out_g), "-o", str(out_g)]) + run(["sct_image", "-i", str(fix), "-copy-header", str(out_t), "-o", str(out_t)]) + + # Load refined 2-D outputs robustly (2-D or (X,Y,1)) + def load2d(path): + arr = nib.load(str(path)).get_fdata() + return arr[:, :, 0] if arr.ndim == 3 else arr + + g_w = load2d(out_g) # (X,Y) + t_w = load2d(out_t) # (X,Y) + + # Overwrite slice z in the *baseline* registered 3-D outputs + ref_g_img = nib.load(str(file_amu_g_step0_ext)) + ref_t_img = nib.load(str(file_amu_t2s_step0_ext)) + g3d = ref_g_img.get_fdata() + t3d = ref_t_img.get_fdata() + g3d[:, :, z] = g_w + t3d[:, :, z] = t_w + nib.Nifti1Image(g3d, ref_g_img.affine, ref_g_img.header).to_filename(str(file_amu_g_step0_ext)) + nib.Nifti1Image(t3d, ref_t_img.affine, ref_t_img.header).to_filename(str(file_amu_t2s_step0_ext)) + + return mat + +# Bottom chain (seed at z_core_min), then go downward on bottom_ext +prev_mat = None +if bottom_ext.size > 0 and z_core_min <= z_core_max: + prev_mat = ants_slice_refine(z_core_min, prev_mat=None) + for z in range(z_core_min - 1, -1, -1): + if z in bottom_ext: + prev_mat = ants_slice_refine(z, prev_mat=prev_mat) + +# Top chain (seed at z_core_max), then go upward on top_ext +prev_mat = None +if top_ext.size > 0 and z_core_min <= z_core_max: + prev_mat = ants_slice_refine(z_core_max, prev_mat=None) + for z in range(z_core_max + 1, Z): + if z in top_ext: + prev_mat = ants_slice_refine(z, prev_mat=prev_mat) + +# Symmetrize and threshold +amu_g_sym = OUTDIR / (amu_g_reg.stem + "_sym.nii.gz") +amu_t2s_sym = OUTDIR / (amu_t2s_reg.stem + "_sym.nii.gz") +run(["python3", SYM_SCRIPT, "-i", amu_g_reg, "--dtype", "float32", "--mode", "average", "-o", amu_g_sym]) +run(["python3", SYM_SCRIPT, "-i", amu_t2s_reg, "--dtype", "float32", "--mode", "average", "-o", amu_t2s_sym]) + +amu_t2s_thr = OUTDIR / (amu_t2s_sym.stem + "_thr.nii.gz") +run(["sct_maths", "-i", amu_t2s_sym, "-thr", "0", "-type", "uint16", "-o", amu_t2s_thr]) + +print("\n=== Outputs ===") +print(f"QC dir: {QCDIR}") +print(f"Step-0 warp: {file_warp0}") +print(f"Step-1 warp: {warp1}") +print(f"AMU GM reg: {amu_g_reg}") +print(f"AMU GM sym: {amu_g_sym}") +print(f"AMU T2* reg: {amu_t2s_reg}") +print(f"AMU T2* sym: {amu_t2s_sym}") +print(f"AMU T2* sym thr16: {amu_t2s_thr}") +print(f"Work dir (intermed.): {workdir}") + diff --git a/scripts/register_AMU_to_PAM50.sh b/scripts/register_AMU_to_PAM50.sh new file mode 100644 index 0000000..988523a --- /dev/null +++ b/scripts/register_AMU_to_PAM50.sh @@ -0,0 +1,18 @@ +# TODO: Download AMU-Poly-MNI zip files provided by Virginie Callot +cd $SCT_DIR/data/PAM50/template +# Create label on MNI-Poly-AMU to identify the z (top part of the visible template). The XY coordinate does not matter too much since pipeline will be followed by centermass registration. +# TODO: remove hardcoding of paths +sct_label_utils -i ~/Desktop/MNI-POLY-AMU/AMU15_T2star_sym.nii.gz -create 75,75,965,1 -o ~/Desktop/MNI-POLY-AMU/label_AMU15.nii.gz +# Create associated label in the PAM50 +sct_label_utils -i PAM50_t2.nii.gz -create 70,70,959,1 -o label_PAM50.nii.gz +# Register AMU15 --> PAM50 +sct_register_multimodal -i ~/Desktop/MNI-POLY-AMU/AMU15_T2star_sym.nii.gz -iseg ~/Desktop/MNI-POLY-AMU/AMU15_GW_sym.nii.gz -ilabel ~/Desktop/MNI-POLY-AMU/label_AMU15.nii.gz -d PAM50_t2.nii.gz -dseg PAM50_t2_seg.nii.gz -dlabel label_PAM50.nii.gz -param step=0,type=label,dof=Tx_Ty_Tz:step=1,type=seg,algo=slicereg,poly=2 -qc qc +# Apply the warping to the AMU15 template and its segmentation +sct_apply_transfo -i ~/Desktop/MNI-POLY-AMU/AMU15_T2star_sym.nii.gz -d PAM50_t2.nii.gz -w warp_AMU15_T2star_sym2PAM50_t2.nii.gz -x linear +sct_apply_transfo -i ~/Desktop/MNI-POLY-AMU/AMU15_G_sym.nii.gz -d PAM50_t2.nii.gz -w warp_AMU15_T2star_sym2PAM50_t2.nii.gz -x linear +# Symmetrize the warped objects +# TODO: remove hardcoding of paths +python3 ~/code/PAM50/scripts/symmetrize_cord_segmentation.py -i AMU15_G_sym_reg.nii.gz --dtype float32 --mode average +python3 ~/code/PAM50/scripts/symmetrize_cord_segmentation.py -i AMU15_T2star_sym_reg.nii.gz --dtype float32 --mode average +# Remove negatives values on the T2*-weighted image by thresholding, and output as UINT16 type (like for PAM50_t2) +sct_maths -i AMU15_T2star_sym_reg_sym.nii.gz -thr 0 -type uint16 -o AMU15_T2star_sym_reg_sym_thr.nii.gz diff --git a/scripts/symmetrize_cord_segmentation.py b/scripts/symmetrize_cord_segmentation.py index 29f8fef..1c79680 100644 --- a/scripts/symmetrize_cord_segmentation.py +++ b/scripts/symmetrize_cord_segmentation.py @@ -1,40 +1,131 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -# +# # This script creates a symmetrical image by copying the information from the right side of the image # to the left side. -# -# It is particularly useful when manually correcting a spinal cord segmentation, because only the right -# part needs to be corrected, and then this script is run to correct the left part. -# -# For more context, see: https://github.com/spinalcordtoolbox/PAM50/issues/19 -# -# How to run: -# cd where this script is located and run: -# python symmetrize_cord_segmentation.py -# -# Author: Julien Cohen-Adad +# +# Usage: +# python symmetrize_cord_segmentation.py -i input.nii.gz [-o output.nii.gz] [-t uint8] +# +# Author: Julien Cohen-Adad (adapted) +import argparse +import os import numpy as np import nibabel as nib -# Open PAM50 spinal cord segmentation -nii_seg = nib.load("../template/PAM50_cord.nii.gz") -data_seg = nii_seg.get_fdata() +def str_to_dtype(dtype_str): + try: + return np.dtype(dtype_str) + except TypeError: + raise argparse.ArgumentTypeError(f"Unsupported dtype: {dtype_str}") -# Symmetrize image by copying the right to the left -data_seg[71:, ...] = np.flip(data_seg[:70, ...], axis=0) -# Use proper dtype -data_seg = np.uint8(data_seg) -header_seg = nii_seg.header.copy() -header_seg.set_data_dtype(np.uint8) +def make_default_output_name(input_fname: str) -> str: + base, ext = os.path.splitext(input_fname) + if ext == ".gz": # handle .nii.gz + base, ext2 = os.path.splitext(base) + ext = ext2 + ext + return base + "_sym" + ext -# Save file -# nii_seg_new = copy.deepcopy(nii_seg) -nii_seg_new = nib.Nifti1Image(data_seg, nii_seg.affine, header_seg) -fname_out = "PAM50_cord_new.nii.gz" -nib.save(nii_seg_new, fname_out) -print(f"Done! ๐ŸŽ‰ \nFile created: {fname_out}") +def symmetrize(data: np.ndarray, axis: int = 0, mode: str = "copy") -> np.ndarray: + """ + Symmetrize `data` along `axis`. + Modes: + - "copy": mirror the right half to the left AND mirror back to the right โ†’ perfectly symmetric. + - "average": average left & mirrored-right, then write that average to BOTH halves. + Center slice (when odd size) is preserved as-is. + """ + n = data.shape[axis] + mid = n // 2 + out = data.copy() + + # convenient index builders + def sl(start, stop): + idx = [slice(None)] * data.ndim + idx[axis] = slice(start, stop) + return tuple(idx) + + def flip_along(arr): + return np.flip(arr, axis=axis) + + if n % 2 == 0: + left = data[sl(0, mid)] # length mid + right = data[sl(mid, n)] # length mid + + if mode == "copy": + # define template as "right" and write both halves symmetrically + templ = right + out[sl(0, mid)] = flip_along(templ) + out[sl(mid, n)] = templ + elif mode == "average": + avg = np.nanmean(np.stack([left, flip_along(right)], axis=0), axis=0) + out[sl(0, mid)] = avg + out[sl(mid, n)] = flip_along(avg) + else: + raise ValueError("mode must be 'copy' or 'average'") + + else: + # odd: left [0:mid], center [mid], right [mid+1:n] (left/right length mid) + left = data[sl(0, mid)] + center = data[sl(mid, mid+1)] + right = data[sl(mid+1, n)] + + if mode == "copy": + templ = right + out[sl(0, mid)] = flip_along(templ) + out[sl(mid, mid+1)] = center # keep center + out[sl(mid+1, n)] = templ + elif mode == "average": + avg = np.nanmean(np.stack([left, flip_along(right)], axis=0), axis=0) + out[sl(0, mid)] = avg + out[sl(mid, mid+1)] = center # keep center + out[sl(mid+1, n)] = flip_along(avg) + else: + raise ValueError("mode must be 'copy' or 'average'") + + return out + + +def main(): + parser = argparse.ArgumentParser(description="Symmetrize a NIfTI image along axis 0.") + parser.add_argument("-i", "--input", required=True, help="Path to input NIfTI file") + parser.add_argument("-o", "--output", default=None, + help="Path to output NIfTI file (default: input_sym.nii.gz)") + parser.add_argument("-t", "--dtype", type=str_to_dtype, default=np.uint8, + help="Output data type (e.g., uint8, int16, float32). Default: uint8") + parser.add_argument("--mode", choices=["copy", "average"], default="copy", + help="Symmetrization mode. Default: copy") + # Optional: expose axis if needed later; default stays 0 + parser.add_argument("--axis", type=int, default=0, + help="Axis to symmetrize along (0-based). Default: 0") + args = parser.parse_args() + + # Load input + nii = nib.load(args.input) + data = nii.get_fdata() # work in float; we'll cast later + + # Symmetrize + data_sym = symmetrize(data, axis=args.axis, mode=args.mode) + + # If integer dtype requested, round to nearest before casting to avoid asymmetry from truncation + if np.issubdtype(args.dtype, np.integer): + data_sym = np.rint(data_sym) + + # Cast & preserve header (qform/sform, zooms, slope/intercept, etc.) + data_sym = data_sym.astype(args.dtype) + header = nii.header.copy() + header.set_data_dtype(args.dtype) + + # Output filename + output_fname = args.output if args.output else make_default_output_name(args.input) + + # Save + nib.save(nib.Nifti1Image(data_sym, nii.affine, header), output_fname) + print(f"โœ… Done! Saved: {output_fname} | dtype={args.dtype} | mode={args.mode} | axis={args.axis}") + + +if __name__ == "__main__": + main()