Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.vscode
.vscode
.DS_Store
331 changes: 331 additions & 0 deletions scripts/register_AMU_to_PAM50.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,331 @@

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Reproducible registration script (no CLI args).
Pipeline:
1) Step-0 (label-based, Tx_Ty_Tz) AMU -> PAM50 to get initial rigid/translation alignment.
2) Apply step-0 warp to AMU images to bring them into PAM50 space.
3) Extend AMU volumes in PAM50 space by copying top/bottom slices to match PAM50 cord extent.
4) Global slicewise registration (SCT) on extended volumes:
sct_register_multimodal step=1, type=seg, algo=slicereg, poly=2 (standard coarse alignment)
5) Apply global warp (warp1 ∘ file_warp0) to produce baseline registered AMU_T2* and FILE_AMU_GM.
6) Detect extended Z-slices and run **per-slice isct_antsRegistration** ONLY on those slices
"de proche en proche", initializing each slice with previous transform.
Overwrite the corresponding slices in the baseline outputs.
7) Symmetrize and threshold (T2*).

Assumptions:
- SCT binaries in PATH: sct_label_utils, sct_register_multimodal, sct_apply_transfo, sct_maths
- symmetrize_cord_segmentation.py is located in the SAME directory as this script
- Edit CONFIG below, then just run the script.
"""
import os
import subprocess
from pathlib import Path
import nibabel as nib
import numpy as np

# =========================
# CONFIG (edit once)
# =========================
SCT_DIR = os.environ.get("SCT_DIR", "/opt/sct") # Adjust if not using env var
FILE_AMU_T2S = Path(os.path.expanduser("~/Desktop/MNI-POLY-AMU/AMU15_T2star_sym.nii.gz"))
FILE_AMU_GM = Path(os.path.expanduser("~/Desktop/MNI-POLY-AMU/AMU15_GW_sym.nii.gz")) # GM+WM segmentation (moving seg)

FILE_PAM50_T2 = Path(f"{SCT_DIR}/data/PAM50/template/PAM50_t2.nii.gz")
FILE_PAM50_SEG = Path(f"{SCT_DIR}/data/PAM50/template/PAM50_cord.nii.gz")

# Label coordinates (x,y,z,val) for step-0 label-based alignment
LABEL_AMU = (75, 75, 965, 1)
LABEL_PAM = (70, 70, 959, 1)

# I/O
OUTDIR = Path("./results")
QCDIR = (OUTDIR / "qc").resolve()

HERE = Path(__file__).resolve().parent
SYM_SCRIPT = HERE / "symmetrize_cord_segmentation.py"

def run(cmd, check=True):
print("+", " ".join(map(str, cmd)), flush=True)
res = subprocess.run(list(map(str, cmd)), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
print(res.stdout)
if check and res.returncode != 0:
raise RuntimeError(f"Command failed: {' '.join(map(str, cmd))}")
return res

def nz_mask_per_slice(vol_path: Path):
img = nib.load(str(vol_path))
data = img.get_fdata()
if data.ndim != 3:
raise ValueError("Expected a 3D volume")
Z = data.shape[2]
mask = np.zeros(Z, dtype=bool)
for z in range(Z):
mask[z] = np.any(data[:, :, z] != 0)
return mask

def copy_edge_slices_to_match(moving_path: Path, ref_path: Path, out_path: Path):
"""
Make the moving image's Z dimension match the reference by either:
- Cropping (if moving is longer in Z)
- Padding and copying edge slices (if moving is shorter in Z)
Alignment uses NZ-span centers to preserve the cord region.
"""
img_m = nib.load(str(moving_path))
img_r = nib.load(str(ref_path))
data_m = img_m.get_fdata()
data_r = img_r.get_fdata()

if data_m.ndim != 3 or data_r.ndim != 3:
raise ValueError("Input images must be 3D.")

def nz_bounds(arr):
proj = np.sum(np.sum(arr != 0, axis=0), axis=0)
nz = np.where(proj > 0)[0]
if nz.size == 0:
return 0, arr.shape[2] - 1
return int(nz.min()), int(nz.max())

zmin_m, zmax_m = nz_bounds(data_m)
zmin_r, zmax_r = nz_bounds(data_r)
z_m = data_m.shape[2]
z_r = data_r.shape[2]

c_m = 0.5 * (zmin_m + zmax_m)

if z_m > z_r:
# Crop moving to match reference length
start = int(round(c_m - z_r / 2.0))
start = max(0, min(start, z_m - z_r))
data_ext = data_m[:, :, start:start + z_r].copy()
elif z_m < z_r:
# Pad moving to match reference length, then copy edges
data_ext = np.zeros((data_m.shape[0], data_m.shape[1], z_r), dtype=data_m.dtype)
insert_start = int(round((z_r / 2.0) - (c_m)))
insert_start = max(0, min(insert_start, z_r - z_m))
data_ext[:, :, insert_start:insert_start + z_m] = data_m

proj = np.any(np.any(data_ext != 0, axis=0), axis=0)
first_nz = np.argmax(proj)
if first_nz > 0:
data_ext[:, :, :first_nz] = np.repeat(data_ext[:, :, first_nz:first_nz + 1], first_nz, axis=2)
proj = np.any(np.any(data_ext != 0, axis=0), axis=0)
last_nz = len(proj) - 1 - np.argmax(proj[::-1])
if last_nz < data_ext.shape[2] - 1:
fill_len = data_ext.shape[2] - 1 - last_nz
data_ext[:, :, last_nz + 1:] = np.repeat(data_ext[:, :, last_nz:last_nz + 1], fill_len, axis=2)
else:
data_ext = data_m.copy()
proj = np.any(np.any(data_ext != 0, axis=0), axis=0)
first_nz = np.argmax(proj)
last_nz = len(proj) - 1 - np.argmax(proj[::-1])
if first_nz > 0:
data_ext[:, :, :first_nz] = np.repeat(data_ext[:, :, first_nz:first_nz + 1], first_nz, axis=2)
if last_nz < data_ext.shape[2] - 1:
fill_len = data_ext.shape[2] - 1 - last_nz
data_ext[:, :, last_nz + 1:] = np.repeat(data_ext[:, :, last_nz:last_nz + 1], fill_len, axis=2)

nib.Nifti1Image(data_ext, img_m.affine, img_m.header).to_filename(str(out_path))
return out_path


# ==========================================================================
# SCRIPT STARTS HERE
# ==========================================================================

OUTDIR.mkdir(parents=True, exist_ok=True)
QCDIR.mkdir(parents=True, exist_ok=True)
workdir = OUTDIR / "work"
workdir.mkdir(parents=True, exist_ok=True)
cwd = os.getcwd()
os.chdir(workdir)

# Create pointwise labels to bring AMU template in PAM50 space
file_label_amu = Path(f"label_AMU.nii.gz")
file_label_pam = Path(f"label_PAM50.nii.gz")
run(["sct_label_utils", "-i", FILE_AMU_T2S, "-create", ",".join(map(str, LABEL_AMU)), "-o", file_label_amu])
run(["sct_label_utils", "-i", FILE_PAM50_T2, "-create", ",".join(map(str, LABEL_PAM)), "-o", file_label_pam])

# Label-based registration (Tx_Ty_Tz), followed by non-linear registration using cord segmentation
run([
"sct_register_multimodal",
"-i", FILE_AMU_T2S,
"-iseg", FILE_AMU_GM,
"-ilabel", file_label_amu,
"-d", FILE_PAM50_T2,
"-dseg", FILE_PAM50_SEG,
"-dlabel", file_label_pam,
"-param", "step=0,type=label,dof=Tx_Ty_Tz:step=1,type=seg,algo=slicereg,poly=2",
"-qc", QCDIR,
])
file_amu_t2s_step0 = Path(FILE_AMU_T2S.name.removesuffix('.nii.gz') + "_step0.nii.gz")
# Rename for clarity
os.rename(Path(FILE_AMU_T2S.name.removesuffix('.nii.gz') + "_reg.nii.gz"), file_amu_t2s_step0)

# Locate warp
srcbase0 = Path(FILE_AMU_T2S).name.replace(".nii.gz","")
dstbase = Path(FILE_PAM50_T2).name.replace(".nii.gz","")
file_warp0 = Path(f"warp_{srcbase0}2{dstbase}.nii.gz")
if not file_warp0.exists():
raise FileNotFoundError("Could not find {file_warp0}")

# Apply warp to AMU cord segmentation
file_amu_g_step0 = Path(FILE_AMU_GM.name.removesuffix('.nii.gz') + "_step0.nii.gz")
run(["sct_apply_transfo", "-i", FILE_AMU_GM, "-d", FILE_PAM50_T2, "-w", file_warp0, "-x", "linear", "-o", file_amu_g_step0])

# Extend top/bottom mask to cover the full PAM50 space
print("==> Extending PAM50-space AMU images along Z to match PAM50 cord segmentation extent.")
file_amu_t2s_step0_ext = Path(file_amu_t2s_step0.name.removesuffix('.nii.gz') + "_ext.nii.gz")
file_amu_g_step0_ext = Path(file_amu_g_step0.name.removesuffix('.nii.gz') + "_ext.nii.gz")
copy_edge_slices_to_match(file_amu_t2s_step0, FILE_PAM50_SEG, file_amu_t2s_step0_ext)
copy_edge_slices_to_match(file_amu_g_step0, FILE_PAM50_SEG, file_amu_g_step0_ext)

# TODO: the code below can be simplified a lot, if we don't care about the "truly extended" slices.
# Also, we might replace the bottom extension with actual concatenation of another set of images.
# ---------- Per-slice refinement ONLY on truly extended slices ----------
def nz_mask_per_slice_data(arr):
Z = arr.shape[2]
mask = np.zeros(Z, dtype=bool)
for z in range(Z):
mask[z] = np.any(arr[:, :, z] != 0)
return mask

# Extended detection via step0 vs step0_ext (GM)
mask_step0 = nz_mask_per_slice(file_amu_g_step0)
mask_ext = nz_mask_per_slice(file_amu_g_step0_ext)
Z = mask_ext.shape[0]
extended = np.logical_and(~mask_step0, mask_ext)
nz_indices = np.where(mask_step0)[0]
if nz_indices.size == 0:
z_core_min, z_core_max = 0, -1
else:
z_core_min, z_core_max = int(nz_indices.min()), int(nz_indices.max())
bottom_ext = np.where(extended[:z_core_min])[0] if z_core_min >= 1 else np.array([], dtype=int)
top_ext = (z_core_max + 1) + np.where(extended[z_core_max + 1:])[0] if z_core_max < Z - 1 else np.array([], dtype=int)

print(f"Core NZ span: [{z_core_min}, {z_core_max}]")
print(f"Bottom extended slices: {bottom_ext.tolist()}")
print(f"Top extended slices: {top_ext.tolist()}")

# Load baseline registered 3D outputs and the extended volumes for picking slices
ref_fix_seg = nib.load(str(FILE_PAM50_SEG))
fix_seg_3d = ref_fix_seg.get_fdata()
gm_ext_3d = nib.load(str(file_amu_g_step0_ext)).get_fdata()
t2s_ext_3d = nib.load(str(file_amu_t2s_step0_ext)).get_fdata()

# Helper to run ants on a single slice and overwrite in the baseline outputs
def ants_slice_refine(z, prev_mat=None):
# Build 2D (single-slice) fixed/moving files
fix = workdir / f"fix_seg_z{z:04d}.nii.gz"
mov_g = workdir / f"mov_g_ext_z{z:04d}.nii.gz"
mov_t = workdir / f"mov_t2s_ext_z{z:04d}.nii.gz"

fix_slice = nib.load(str(FILE_PAM50_SEG)).get_fdata()
gm_ext_3d = nib.load(str(file_amu_g_step0_ext)).get_fdata()
t2_ext_3d = nib.load(str(file_amu_t2s_step0_ext)).get_fdata()

# write as single-slice 2D
# aff2d = affine_3d_to_2d(nib.load(str(FILE_PAM50_SEG)).affine)
# nib.Nifti1Image(fix_slice[:, :, z], aff2d, header_3d_to_2d(nib.load(str(FILE_PAM50_SEG)).header, np.shape(fix_slice[:, :, z]), aff2d)).to_filename(str(fix))
nib.Nifti1Image(fix_slice[:, :, z], nib.load(str(FILE_PAM50_SEG)).affine, nib.load(str(FILE_PAM50_SEG)).header).to_filename(str(fix))
nib.Nifti1Image(gm_ext_3d[:, :, z], nib.load(str(file_amu_g_step0_ext)).affine, nib.load(str(file_amu_g_step0_ext)).header).to_filename(str(mov_g))
nib.Nifti1Image(t2_ext_3d[:, :, z], nib.load(str(file_amu_t2s_step0_ext)).affine, nib.load(str(file_amu_t2s_step0_ext)).header).to_filename(str(mov_t))

# Run 2D rigid ANTs (use previous slice’s affine as init if provided)
out_prefix = workdir / f"ants_z{z:04d}_"
cmd = ["isct_antsRegistration",
"-d", "2",
# Stage 1: rigid
"-t", "Rigid[0.1]",
"-m", f"MeanSquares[{fix},{mov_g},1,4]",
"-c", "50x20",
"-s", "1x0",
"-f", "2x1",
# Stage 2: bsplineSyN non-linear refinement
"-t", "BSplineSyN[0.1,26,0,3]",
"-m", f"MeanSquares[{fix},{mov_g},1,4]",
"-c", "50x20",
"-s", "1x0",
"-f", "2x1"]
if prev_mat is not None:
cmd += ["-r", str(prev_mat)]
cmd += ["-o", str(out_prefix)]
run(cmd)

mat = Path(str(out_prefix) + "0GenericAffine.mat")
if not mat.exists():
raise FileNotFoundError(f"ants transform not found for slice z={z}: {mat}")
warp = Path(str(out_prefix) + "1Warp.nii.gz")
if not warp.exists():
raise FileNotFoundError(f"ants transform not found for slice z={z}: {warp}")

# Apply to GM (NN) and T2* (linear) -- ANTs will write 2-D images
out_g = workdir / f"amu_g_refined_z{z:04d}.nii.gz"
out_t = workdir / f"amu_t2s_refined_z{z:04d}.nii.gz"
run(["isct_antsApplyTransforms", "-d", "2",
"-i", mov_g, "-r", fix, "-t", mat, warp, "-o", out_g, "-n", "Linear"])
run(["isct_antsApplyTransforms", "-d", "2",
"-i", mov_t, "-r", fix, "-t", mat, warp, "-o", out_t, "-n", "Linear"])
# Copy header info (spacing) from input to output
# TODO replace code below with nibabel for faster execution
run(["sct_image", "-i", str(fix), "-copy-header", str(out_g), "-o", str(out_g)])
run(["sct_image", "-i", str(fix), "-copy-header", str(out_t), "-o", str(out_t)])

# Load refined 2-D outputs robustly (2-D or (X,Y,1))
def load2d(path):
arr = nib.load(str(path)).get_fdata()
return arr[:, :, 0] if arr.ndim == 3 else arr

g_w = load2d(out_g) # (X,Y)
t_w = load2d(out_t) # (X,Y)

# Overwrite slice z in the *baseline* registered 3-D outputs
ref_g_img = nib.load(str(file_amu_g_step0_ext))
ref_t_img = nib.load(str(file_amu_t2s_step0_ext))
g3d = ref_g_img.get_fdata()
t3d = ref_t_img.get_fdata()
g3d[:, :, z] = g_w
t3d[:, :, z] = t_w
nib.Nifti1Image(g3d, ref_g_img.affine, ref_g_img.header).to_filename(str(file_amu_g_step0_ext))
nib.Nifti1Image(t3d, ref_t_img.affine, ref_t_img.header).to_filename(str(file_amu_t2s_step0_ext))

return mat

# Bottom chain (seed at z_core_min), then go downward on bottom_ext
prev_mat = None
if bottom_ext.size > 0 and z_core_min <= z_core_max:
prev_mat = ants_slice_refine(z_core_min, prev_mat=None)
for z in range(z_core_min - 1, -1, -1):
if z in bottom_ext:
prev_mat = ants_slice_refine(z, prev_mat=prev_mat)

# Top chain (seed at z_core_max), then go upward on top_ext
prev_mat = None
if top_ext.size > 0 and z_core_min <= z_core_max:
prev_mat = ants_slice_refine(z_core_max, prev_mat=None)
for z in range(z_core_max + 1, Z):
if z in top_ext:
prev_mat = ants_slice_refine(z, prev_mat=prev_mat)

# Symmetrize and threshold
amu_g_sym = OUTDIR / (amu_g_reg.stem + "_sym.nii.gz")
amu_t2s_sym = OUTDIR / (amu_t2s_reg.stem + "_sym.nii.gz")
run(["python3", SYM_SCRIPT, "-i", amu_g_reg, "--dtype", "float32", "--mode", "average", "-o", amu_g_sym])
run(["python3", SYM_SCRIPT, "-i", amu_t2s_reg, "--dtype", "float32", "--mode", "average", "-o", amu_t2s_sym])

amu_t2s_thr = OUTDIR / (amu_t2s_sym.stem + "_thr.nii.gz")
run(["sct_maths", "-i", amu_t2s_sym, "-thr", "0", "-type", "uint16", "-o", amu_t2s_thr])

print("\n=== Outputs ===")
print(f"QC dir: {QCDIR}")
print(f"Step-0 warp: {file_warp0}")
print(f"Step-1 warp: {warp1}")
print(f"AMU GM reg: {amu_g_reg}")
print(f"AMU GM sym: {amu_g_sym}")
print(f"AMU T2* reg: {amu_t2s_reg}")
print(f"AMU T2* sym: {amu_t2s_sym}")
print(f"AMU T2* sym thr16: {amu_t2s_thr}")
print(f"Work dir (intermed.): {workdir}")

18 changes: 18 additions & 0 deletions scripts/register_AMU_to_PAM50.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# TODO: Download AMU-Poly-MNI zip files provided by Virginie Callot
cd $SCT_DIR/data/PAM50/template
# Create label on MNI-Poly-AMU to identify the z (top part of the visible template). The XY coordinate does not matter too much since pipeline will be followed by centermass registration.
# TODO: remove hardcoding of paths
sct_label_utils -i ~/Desktop/MNI-POLY-AMU/AMU15_T2star_sym.nii.gz -create 75,75,965,1 -o ~/Desktop/MNI-POLY-AMU/label_AMU15.nii.gz
# Create associated label in the PAM50
sct_label_utils -i PAM50_t2.nii.gz -create 70,70,959,1 -o label_PAM50.nii.gz
# Register AMU15 --> PAM50
sct_register_multimodal -i ~/Desktop/MNI-POLY-AMU/AMU15_T2star_sym.nii.gz -iseg ~/Desktop/MNI-POLY-AMU/AMU15_GW_sym.nii.gz -ilabel ~/Desktop/MNI-POLY-AMU/label_AMU15.nii.gz -d PAM50_t2.nii.gz -dseg PAM50_t2_seg.nii.gz -dlabel label_PAM50.nii.gz -param step=0,type=label,dof=Tx_Ty_Tz:step=1,type=seg,algo=slicereg,poly=2 -qc qc
# Apply the warping to the AMU15 template and its segmentation
sct_apply_transfo -i ~/Desktop/MNI-POLY-AMU/AMU15_T2star_sym.nii.gz -d PAM50_t2.nii.gz -w warp_AMU15_T2star_sym2PAM50_t2.nii.gz -x linear
sct_apply_transfo -i ~/Desktop/MNI-POLY-AMU/AMU15_G_sym.nii.gz -d PAM50_t2.nii.gz -w warp_AMU15_T2star_sym2PAM50_t2.nii.gz -x linear
# Symmetrize the warped objects
# TODO: remove hardcoding of paths
python3 ~/code/PAM50/scripts/symmetrize_cord_segmentation.py -i AMU15_G_sym_reg.nii.gz --dtype float32 --mode average
python3 ~/code/PAM50/scripts/symmetrize_cord_segmentation.py -i AMU15_T2star_sym_reg.nii.gz --dtype float32 --mode average
# Remove negatives values on the T2*-weighted image by thresholding, and output as UINT16 type (like for PAM50_t2)
sct_maths -i AMU15_T2star_sym_reg_sym.nii.gz -thr 0 -type uint16 -o AMU15_T2star_sym_reg_sym_thr.nii.gz
Loading