Running NewSegment properly #258

Open · wants to merge 2 commits into base: master
248 changes: 223 additions & 25 deletions pypreprocess/nipype_preproc_spm_utils.py
@@ -75,11 +75,11 @@ def _configure_backends(spm_dir=None, matlab_exec=None, spm_mcr=None,

# prepare template TPMs
tissue1 = ((os.path.join(SPM_DIR, tissue_path, 'TPM.nii'), 1),
2, (True, True), (False, False))
2, (True, True), (True, True))
Member:
A short comment on these parameters would be welcome
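A possible short comment for these parameters, based on nipype's documented NewSegment.inputs.tissues format (the field order below is taken from the nipype docstring and should be double-checked):

# Each tissue entry is ((tpm_file, 1-based frame index), n_gaussians,
#                       (write native, write DARTEL-imported),
#                       (write warped unmodulated, write warped modulated)).
# (True, True), (True, True) therefore keeps the native and DARTEL-ready maps
# and writes both unmodulated and modulated warped tissue maps.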

tissue2 = ((os.path.join(SPM_DIR, tissue_path, 'TPM.nii'), 2),
2, (True, True), (False, False))
2, (True, True), (True, True))
tissue3 = ((os.path.join(SPM_DIR, tissue_path, 'TPM.nii'), 3),
2, (True, False), (False, False))
2, (True, True), (True, True))
tissue4 = ((os.path.join(SPM_DIR, tissue_path, 'TPM.nii'), 4),
3, (False, False), (False, False))
tissue5 = ((os.path.join(SPM_DIR, tissue_path, 'TPM.nii'), 5),
@@ -664,6 +664,144 @@ def _do_subject_segment(subject_data, output_modulated_tpms=True, spm_dir=None,
return subject_data.sanitize()


def _do_subject_newsegment(subject_data, output_modulated_tpms=True,
spm_dir=None, matlab_exec=None, spm_mcr=None,
normalize=False, caching=True, report=True,
software="spm", hardlink_output=True):
"""
Wrapper for running spm.NewSegment with optional reporting.

If subject_data has a `results_gallery` attribute, then QA thumbnails will
be committed after this node is executed.

Parameters
-----------
subject_data: `SubjectData` object
subject data whose anatomical image (subject_data.anat) is to be
segmented

output_modulated_tpms: bool, optional (default True)
if set, then modulated TPMs will be produced (alongside unmodulated
TPMs); this can be useful for VBM

caching: bool, optional (default True)
if true, then caching will be enabled

normalize: bool, optional (default False)
flag indicating whether warped brain compartments (gm, wm, csf) are to
be generated (necessary if the caller wishes to normalize the brain later on)
Member:

I don't understand.

Collaborator (author):

This is a flag related to spm.Segment. It doesn't hold for spm.NewSegment: to be removed.

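For context on the exchange above, a minimal sketch of the API difference, assuming nipype's spm.Segment and spm.NewSegment interfaces (the [modulated, normalised, native] ordering is taken from the nipype docstring and worth verifying):

from nipype.interfaces import spm

# Old-style Segment: warped tissue maps are requested per class with a
# [modulated, normalised, native] boolean list, which is what the
# `normalize` / `output_modulated_tpms` flags toggle.
seg = spm.Segment()
seg.inputs.gm_output_type = [False, True, True]  # unmodulated warped + native GM

# NewSegment: the same choice lives inside each tissue tuple,
#   ((tpm_file, frame), n_gaussians, (native, dartel), (unmodulated, modulated)),
# so there is no per-class *_output_type input left for a `normalize` flag to drive.
newseg = spm.NewSegment()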

report: bool, optional (default True)
flag controlling whether post-preprocessing reports should be generated

Returns
-------
subject_data: `SubjectData` object
preprocessed subject_data

New Attributes
==============
subject_data.nipype_results['segment']: Nipype output object
(raw) result of running spm.NewSegment

subject_data.gm: string
path to subject's segmented gray matter image in native space

subject_data.wm: string
path to subject's segmented white matter image in native space

subject_data.csf: string
path to subject's CSF image in native space

if normalize is set, then the following additional data fields are
populated:

subject_data.mwgm: string
path to subject's modulated, warped gray matter image in standard space

subject_data.mwwm: string
path to subject's modulated, warped white matter image in standard space

subject_data.mwcsf: string
path to subject's modulated, warped CSF image in standard space


Notes
-----
Input subject_data is modified.

"""

# sanitize software choice
software = software.lower()
if software != "spm":
raise NotImplementedError("Only SPM is supported; got '%s'" % software)

# configure SPM back-end
_configure_backends(spm_dir=spm_dir, matlab_exec=matlab_exec,
spm_mcr=spm_mcr)
assert SPM_DIR is not None and os.path.isdir(SPM_DIR), (
"SPM_DIR '%s' doesn't exist; you need to export it!" % SPM_DIR)

# sanitize subject_data (do things like .nii.gz -> .nii conversion, etc.)
subject_data.sanitize(niigz2nii=(software == "spm"))

# prepare for smart caching
if caching:
cache_dir = os.path.join(subject_data.scratch, 'cache_dir')
if not os.path.exists(cache_dir):
os.makedirs(cache_dir)
segment = NipypeMemory(base_dir=cache_dir).cache(spm.NewSegment)
else:
segment = spm.NewSegment().run

# configure node
if not normalize:
gm_output_type = [False, False, True]
wm_output_type = [False, False, True]
csf_output_type = [False, False, True]
else:
gm_output_type = [output_modulated_tpms, True, True]
wm_output_type = [output_modulated_tpms, True, True]
csf_output_type = [output_modulated_tpms, True, True]
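# NOTE: the *_output_type lists above follow the old spm.Segment convention
# ([modulated, normalised, native]); spm.NewSegment takes its per-tissue
# output choices from the TISSUES tuples instead, so these lists are not
# passed to the node call below.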
# run node
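# write_deformation_fields is [inverse, forward] (per the nipype docstring);
# the forward field (y_*.nii) is what gets stored as
# subject_data.deformation_file further down.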
segment_result = segment(
channel_files=subject_data.anat,
write_deformation_fields=[True, True],
tissues=TISSUES,
ignore_exception=False
)

# failed node
subject_data.nipype_results['segment'] = segment_result
if segment_result.outputs is None:
subject_data.failed = True
return subject_data

# collect output
subject_data.parameter_file = segment_result.outputs.transformation_mat[0]
subject_data.deformation_file = segment_result.outputs.forward_deformation_field
subject_data.nipype_results['segment'] = segment_result

subject_data.gm = segment_result.outputs.native_class_images[0][0]
subject_data.wm = segment_result.outputs.native_class_images[1][0]
subject_data.csf = segment_result.outputs.native_class_images[2][0]
if normalize:
subject_data.mwgm = segment_result.outputs.modulated_class_images[0][0]
subject_data.mwwm = segment_result.outputs.modulated_class_images[1][0]
subject_data.mwcsf = segment_result.outputs.modulated_class_images[2][0]

# commit output files
if hardlink_output:
subject_data.hardlink_output_files()

# generate segmentation thumbs
if report:
subject_data.generate_segmentation_thumbnails()

return subject_data.sanitize()
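A minimal usage sketch for the new wrapper; the SubjectData constructor arguments shown here are illustrative assumptions, not taken from this diff:

from pypreprocess.subject_data import SubjectData

sd = SubjectData(subject_id='sub001', anat='/path/to/anat.nii',
                 output_dir='/tmp/preproc/sub001')
sd = _do_subject_newsegment(sd, normalize=True, output_modulated_tpms=True,
                            caching=True, report=False)
if not sd.failed:
    print(sd.gm, sd.wm, sd.csf)   # native-space tissue maps
    print(sd.deformation_file)    # forward deformation field (y_*.nii)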


def _do_subject_normalize(subject_data, fwhm=0., anat_fwhm=0., caching=True,
spm_dir=None, matlab_exec=None, spm_mcr=None,
func_write_voxel_sizes=[3, 3, 3],
@@ -730,12 +868,23 @@ def _do_subject_normalize(subject_data, fwhm=0., anat_fwhm=0., caching=True,
# sanitize subject_data (do things like .nii.gz -> .nii conversion, etc.)
subject_data.sanitize(niigz2nii=(software == "spm"))

# XXX get spm version
spm_version = _get_version_spm(SPM_DIR)

# prepare for smart caching
if caching:
cache_dir = os.path.join(subject_data.scratch, 'cache_dir')
if not os.path.exists(cache_dir): os.makedirs(cache_dir)
normalize = NipypeMemory(base_dir=cache_dir).cache(spm.Normalize)
else: normalize = spm.Normalize().run
if spm_version == 'spm8':
normalize = NipypeMemory(base_dir=cache_dir).cache(spm.Normalize)
elif spm_version == 'spm12':
normalize = NipypeMemory(base_dir=cache_dir).cache(spm.Normalize12)
else:
# XXX normalize or normalize12
if spm_version == 'spm8':
normalize = spm.Normalize().run
elif spm_version == 'spm12':
normalize = spm.Normalize12().run

segmented = 'segment' in subject_data.nipype_results

@@ -751,6 +900,8 @@ def _do_subject_normalize(subject_data, fwhm=0., anat_fwhm=0., caching=True,
else:
parameter_file = subject_data.nipype_results[
'segment'].outputs.transformation_mat
deformation_file = subject_data.nipype_results[
'segment'].outputs.forward_deformation_field

subject_data.parameter_file = parameter_file

@@ -770,14 +921,23 @@ def _do_subject_normalize(subject_data, fwhm=0., anat_fwhm=0., caching=True,
write_voxel_sizes = get_vox_dims(apply_to_files)
else: write_voxel_sizes = anat_write_voxel_sizes
apply_to_files = subject_data.anat

# run node
# XXX replace by normalize12
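# note: with SPM12, spm.Normalize12 in 'write' mode applies the forward
# deformation field (y_*.nii) produced by (New)Segment; the *_seg_sn.mat
# parameter_file is only meaningful for the old spm.Normalize interface.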
print('[DEFORMATION FILE]', deformation_file)
print('[APPLY TO FILE]', apply_to_files)
normalize_result = normalize(
parameter_file=parameter_file,
deformation_file=deformation_file,
apply_to_files=apply_to_files,
write_voxel_sizes=list(write_voxel_sizes),
# write_bounding_box=[[-78, -112, -50], [78, 76, 85]],
write_interp=1, jobtype='write', ignore_exception=True)
write_interp=1, jobtype='write', ignore_exception=False)
print('[RESULTS]', normalize_result.outputs)

# run node
# normalize_result = normalize(
# parameter_file=parameter_file,
# apply_to_files=apply_to_files,
# write_voxel_sizes=list(write_voxel_sizes),
# # write_bounding_box=[[-78, -112, -50], [78, 76, 85]],
# write_interp=1, jobtype='write', ignore_exception=False)

# failed node ?
if normalize_result.outputs is None:
@@ -1110,6 +1270,7 @@ def do_subject_preproc(
coreg_anat_to_func=False,
coregister_software="spm",
segment=True,
newsegment=True,
normalize=True,
dartel=False,
fwhm=0.,
@@ -1325,9 +1486,15 @@ def do_subject_preproc(
# segmentation of anatomical image
#####################################
if segment:
subject_data = _do_subject_segment(
subject_data, caching=caching, normalize=normalize, report=report,
hardlink_output=hardlink_output)
# XXX newsegment goes here
if newsegment:
subject_data = _do_subject_newsegment(
subject_data, caching=caching, normalize=normalize,
report=report, hardlink_output=hardlink_output)
else:
subject_data = _do_subject_segment(
subject_data, caching=caching, normalize=normalize,
report=report, hardlink_output=hardlink_output)

# handle failed node
if subject_data.failed:
@@ -1379,12 +1546,12 @@ def do_subject_preproc(
return subject_data.sanitize()


def _do_subjects_newsegment(
def _do_subjects_dartel(
subjects, output_dir, spm_dir=None, matlab_exec=None,
spm_mcr=None, fwhm=0, anat_fwhm=0., n_jobs=-1, report=True,
func_write_voxel_sizes=None, anat_write_voxel_sizes=None,
output_modulated_tpms=False, parent_results_gallery=None,
do_dartel=True, **kwargs):
**kwargs):
"""
Runs DARTEL (and optionally DartelNorm2MNI) on subjects whose anatomical images were already segmented with spm.NewSegment.

@@ -1401,6 +1568,8 @@ def _do_subjects_newsegment(
os.makedirs(cache_dir)
mem = NipypeMemory(base_dir=cache_dir)

# XXX put this in do_subject_newsegment
"""
# create node
newsegment = mem.cache(spm.NewSegment)

@@ -1425,18 +1594,39 @@ def _do_subjects_newsegment(
sd.generate_segmentation_thumbnails()
if not do_dartel:
return subjects
"""
# TODO check if newsegment was done properly
# TODO build newsegment_result properly
# for sd in subjects:
# print(sd.nipype_results['segment'].outputs.dartel_input_images)
# dartel_inputs = [
# sd.nipype_results['segment'].outputs.dartel_input_images
# for sd in subjects]

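# regroup the DARTEL inputs: NewSegment returns, for each subject, one
# DARTEL-import image per tissue class, whereas spm.DARTEL expects one list
# per tissue class covering all subjects.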
dartel_input_images = []
for i in range(6):
tpm_subjects = []
for sd in subjects:
tpms = sd.nipype_results['segment'].outputs.dartel_input_images[i]
if tpms:
tpm_subjects.extend(tpms)
if tpm_subjects:
dartel_input_images.append(tpm_subjects)

print(dartel_input_images)

# compute DARTEL template for group data
dartel = mem.cache(spm.DARTEL)
dartel_input_images = [
tpms for tpms in newsegment_result.outputs.dartel_input_images if tpms]
# dartel_input_images = [
# tpms for tpms in newsegment_result.outputs.dartel_input_images if tpms]
# dartel_input_images = [tpms for tpms in dartel_inputs if tpms]
dartel_result = dartel(image_files=dartel_input_images)
if dartel_result.outputs is None:
return

for j, subject_data in enumerate(subjects):
subject_data.gm = newsegment_result.outputs.dartel_input_images[0][j]
subject_data.wm = newsegment_result.outputs.dartel_input_images[1][j]
# subject_data.gm = newsegment_result.outputs.dartel_input_images[0][j]
# subject_data.wm = newsegment_result.outputs.dartel_input_images[1][j]
subject_data.dartel_flow_fields = dartel_result.outputs\
.dartel_flow_fields[j]

@@ -1571,8 +1761,9 @@ def do_subjects_preproc(subject_factory, session_ids=None, **preproc_params):

# DARTEL or NewSegment on 1 subject is senseless
if len(subjects) < 2:
if newsegment:
warnings.warn("There is only one subject. Disabling NewSegment.")
# TODO remove this
# if newsegment:
# warnings.warn("There is only one subject. Disabling NewSegment.")
if dartel:
warnings.warn("There is only one subject. Disabling DARTEL.")
dartel = False
@@ -1699,9 +1890,10 @@ def finalize_report():

normalize = preproc_params.get("normalize", True)

# XXX should we check this ?
# don't yet segment nor normalize if dartel enabled
if newsegment:
for stage in ["segment", "normalize", "last_stage"]:
if dartel:
for stage in ["normalize", "last_stage"]:
preproc_params[stage] = False

# postpone smoothing
@@ -1716,8 +1908,14 @@ def finalize_report():

# run DARTEL
preproc_params.update(backup_params)
if newsegment:
subjects = _do_subjects_newsegment(
# XXX replace by do_subjects_dartel
# if newsegment:
# subjects = _do_subjects_newsegment(
# subjects, scratch, n_jobs=n_jobs, do_dartel=dartel,
# **preproc_params)

if dartel:
subjects = _do_subjects_dartel(
subjects, scratch, n_jobs=n_jobs, do_dartel=dartel,
**preproc_params)
