lesion_generate_predictions.py
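"""Generate lesion segmentation predictions for a directory of validation volumes.

Loads a trained model, runs it slice by slice over the pre-processed liver crops
of every volume in the data directory, reassembles the per-slice predictions into
volumes of the original size, and saves each result next to its input volume with
the configured postfix ("_prediction" by default).
"""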
import os.path
import argparse

import numpy as np
import tensorflow as tf
import nibabel
import scipy.ndimage

import utils.lesion_preprocessing as preprocessing
import architecture.networks as networks
# Command Line Arguments
parser = argparse.ArgumentParser()
parser.add_argument('--model', '-m', help='path to the trained model checkpoint.')
parser.add_argument('--data_directory', '-d', help='the directory which contains the validation data set.')
parser.add_argument('--name', '-n', help='the name of the experiment.', default=None)
parser.add_argument('--out_postfix', '-o', help='this postfix will be appended to the file name of every saved prediction.', default="_prediction")
parser.add_argument('--no_batch_norm', help='set if you want to load a model without batch normalization.', action='store_true')
parser.add_argument('--no_crelu', help='set if you want to load a model with standard ReLUs instead of cReLUs.', action='store_true')
parser.add_argument('--unet', help='set if you want to load the standard U-Net model.', action='store_true')
args = parser.parse_args()
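# Example invocation (illustrative; the checkpoint and data paths below are placeholders):
#   python lesion_generate_predictions.py -m checkpoints/model.ckpt -d data/validation/ -n experiment_1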
# Parameters
model = args.model
data_dir = os.path.join(args.data_directory, '')  # ensure a trailing path separator

batch_size = 1

# The network input stacks in_channels adjacent slices; the prediction is written
# at the index of the center slice, so the first and last edge_radius slices of
# each volume receive no prediction.
in_channels = 5
edge_radius = (in_channels - 1) // 2

tf.reset_default_graph()

batch_norm = not args.no_batch_norm
activation_function = "ReLU" if args.no_crelu else "cReLU"
if args.unet:
    print "Setting up the standard U-Net architecture ..."
    tf_inputs, tf_logits, _, tf_keep_prob = networks.unet(in_channels=in_channels, out_channels=2, start_filters=64, input_side_length=256, sparse_labels=True, batch_size=batch_size, padded_convolutions=True)
else:
    print "Setting up the architecture with {} and batch normalization {} ...".format(activation_function, "enabled" if batch_norm else "disabled")
    start_filters = 90 if args.no_crelu else 64
    tf_inputs, tf_logits, _, tf_keep_prob, tf_training = networks.parameter_efficient(in_channels=in_channels, out_channels=2, start_filters=start_filters, input_side_length=256, sparse_labels=True, batch_size=batch_size, activation=activation_function, batch_norm=batch_norm)

# Per-pixel class prediction: argmax over the two output channels
tf_prediction = tf.to_int32(tf.argmax(tf_logits, 3, name='prediction'))
np_inputs = np.zeros([batch_size, 256, 256, in_channels], dtype=np.float32)
saver = tf.train.Saver(max_to_keep=100, keep_checkpoint_every_n_hours=1)
print "Loading pre-processing pipeline"
validation_pipeline = preprocessing.generate_predictions(data_dir, save_name=args.name)
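# As used below, the pipeline yields one (inputs, parameters) pair per slice:
# inputs[0] is the 256x256 stack of adjacent slices fed to the network, inputs[1]
# is the liver mask of the slice, and parameters carries the file names, NIfTI
# header and crop geometry needed to place the prediction back into the volume.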
with tf.Session() as sess:
    print "Loading model {}".format(model)
    saver.restore(sess, model)

    print "Starting volume generation"

    slice_counter = edge_radius
    name = ""
    out_volume = None
    header = None

    for inputs, parameters in validation_pipeline:

        # Check if start of new volume
        if parameters["file_names"][1] != name:

            # If the new volume is not the first, save the volume that came before
            if out_volume is not None:
                if slice_counter != out_volume.shape[-1] - edge_radius:
                    raise RuntimeError("slice_counter: {}, volume.shape: {}, for volume {}".format(slice_counter, out_volume.shape, name))

                # Create nibabel volume and save it
                img = nibabel.Nifti1Image(out_volume, header.get_base_affine(), header=header)
                img.set_data_dtype(np.uint8)
                name_parts = os.path.splitext(os.path.basename(name))
                nibabel.save(img, os.path.join(data_dir, name_parts[0] + args.out_postfix + name_parts[1]))

            # Parameters of the new volume
            name = parameters["file_names"][1]
            header = parameters["nifti_header"]
            dimensions = header.get_data_shape()[:3]

            # Reset variables
            slice_counter = edge_radius
            out_volume = np.zeros(dimensions, dtype=np.uint8)

        # No need to predict lesions if there is no liver on the slice
        if 1 in inputs[1]:

            # Prepare network input
            np_inputs[0, :, :, :] = inputs[0]

            feed_dict = {
                tf_inputs: np_inputs,
                tf_keep_prob: 1.0
            }

            # The parameter-efficient architecture also expects the training flag
            if not args.unet:
                feed_dict[tf_training] = False

            # Run network and obtain prediction
            np_prediction = sess.run(tf_prediction, feed_dict=feed_dict)

            # Re-orient and re-size the prediction for the liver crop to fit the original volume
            np_prediction = np.transpose(np_prediction[0, :, :])

            crop_indices = parameters["crop_indices"]
            side_lengths = parameters["crop_canvas_size"]
            image_indices = parameters["image_indices"]

            # Nearest-neighbour zoom from 256x256 back to the crop canvas size
            zooms = np.asarray(side_lengths, dtype=np.float64) / np.asarray([256., 256.], dtype=np.float64)
            np_prediction = scipy.ndimage.zoom(np_prediction, zooms, order=0)

            # Lesions have class label 2
            np_prediction = np.round(np_prediction).astype(np.uint8) * 2

            # Write the prediction for the liver crop to the volume with the original size
            out_volume[image_indices[0]:image_indices[1], image_indices[2]:image_indices[3], slice_counter] = np_prediction[crop_indices[0]:crop_indices[1], crop_indices[2]:crop_indices[3]]

        slice_counter += 1

    if slice_counter != out_volume.shape[-1] - edge_radius:
        raise RuntimeError("slice_counter: {}, volume.shape: {}, for last volume: {}".format(slice_counter, out_volume.shape, name))

    # Create nibabel volume for the last volume and save it
    img = nibabel.Nifti1Image(out_volume, header.get_base_affine(), header=header)
    img.set_data_dtype(np.uint8)
    name_parts = os.path.splitext(os.path.basename(name))
    nibabel.save(img, os.path.join(data_dir, name_parts[0] + args.out_postfix + name_parts[1]))

# After volume creation, release the pre-processing pipeline
validation_pipeline.close()
print "Done!"