From 3a667e13a7dd51e95cdcf4125ef5fbf546327beb Mon Sep 17 00:00:00 2001 From: Jordan G Date: Wed, 17 May 2017 17:59:09 +0200 Subject: [PATCH 1/3] ADD : add scope for tensorboard + parameters to fix output image size --- GenerateUNetModel.py | 82 ++++++++++++++++++ PredictDirectoryUNet.py | 37 ++++++++ tf_unet/image_util.py | 22 ++++- tf_unet/layers.py | 87 +++++++++++-------- tf_unet/unet.py | 187 +++++++++++++++++++++++----------------- 5 files changed, 302 insertions(+), 113 deletions(-) create mode 100644 GenerateUNetModel.py create mode 100644 PredictDirectoryUNet.py diff --git a/GenerateUNetModel.py b/GenerateUNetModel.py new file mode 100644 index 0000000..cfb7cf4 --- /dev/null +++ b/GenerateUNetModel.py @@ -0,0 +1,82 @@ +import sys, getopt, os +import numpy as np +import shutil +from tf_unet import image_util, unet + + +""" +Original paper : +https://lmb.informatik.uni-freiburg.de/Publications/2015/RFB15a/ + Parameters['layers'] = 5 + Parameters['convolutionFilter'] = 3 +""" + + +#Change the following field to configure your UNet +def devTest(): + Parameters = dict() + Parameters['layers'] = 3 + Parameters['convolutionFilter'] = 7 + Parameters['depthConvolutionFilter'] = 16 + Parameters['channel'] = 3 + Parameters['outputClassNumber'] = 2 + Parameters['regulationCoefficient'] = 0.01 + Parameters['optimizerType'] = "momentum" + Parameters['optimizerValue'] = 0.90 + Parameters['learningRate'] = 0.01 + Parameters['batchSize'] = 2 + Parameters['decayRate'] = 0.95 + Parameters['dropout'] = 0.8 + Parameters['epoch'] = 100 + Parameters['iterations'] = 10 + Parameters['datasrc'] = "C:/Work/Projets/Provital - Vergeture/Images apprentissage/test2" + Parameters['output'] = "C:/Work/Projets/Provital - Vergeture/Images apprentissage/Model" + return Parameters + +if __name__ == "__main__": + + Parameters = devTest() + +# !!! Remove current Parameters[output] directory to create a new one !!! +if (os.path.isdir(Parameters['output'] + '/model')): + shutil.rmtree(Parameters['output']) +os.mkdir(Parameters['output']) +os.mkdir(Parameters['output'] + '/model') +os.mkdir(Parameters['output'] + '/trainPrediction') +os.mkdir(Parameters['output'] + '/testPrediction') + +#Load images to train UNet +data_provider = image_util.ImageDataProvider(Parameters['datasrc'] + '/*', data_suffix=".png", mask_suffix="_mask.png") + +# Compute the theoric total number of convolution filters : +# ConvFilter in descending path + ConvFilter in expanding path + output convolution (1x1) + up-convolution (2x2 +totConvFilter = (Parameters['layers']*2) + (Parameters['layers']-1)*2 + 1 + (Parameters['layers'] - 1) +print("Total number of convolution filters (attempted) : " + str(totConvFilter)) + +#Compute the number of iterations necessary to see all train dataset in one batch +if int(Parameters['iterations']) == 0: + Parameters['iterations'] = round(len(data_provider.data_files) / int(Parameters['batchSize'])) + +net = unet.Unet(layers=int(Parameters['layers']), + features_root=int(Parameters['depthConvolutionFilter']), + channels=int(Parameters['channel']), + regularisationConstant=np.float32(Parameters['regulationCoefficient']), + n_class=int(Parameters['outputClassNumber']), + filter_size=int(Parameters['convolutionFilter'])) + +print("\tUNet created") +trainer = unet.Trainer(net, + optimizer=str(Parameters['optimizerType']), + batch_size=int(Parameters['batchSize']), + opt_kwargs=dict(momentum=np.float32(Parameters['optimizerValue']), + learning_rate=np.float32(Parameters['learningRate']), + decay_rate=np.float32(Parameters['decayRate']))) +print("\tUNet initialized") +path = trainer.train(data_provider, + str(Parameters['output']), + str(Parameters['output'] + '/trainPrediction'), + dropout=np.float32(Parameters['dropout']), + training_iters=int(Parameters['iterations']), + epochs=int(Parameters['epoch'])) + +print("\n\tEND of learning step !") \ No newline at end of file diff --git a/PredictDirectoryUNet.py b/PredictDirectoryUNet.py new file mode 100644 index 0000000..0ed864f --- /dev/null +++ b/PredictDirectoryUNet.py @@ -0,0 +1,37 @@ +from tf_unet import util, image_util, unet +import sys +import glob +import os + +if __name__ == "__main__": + + #CHANGE YO YOUR PATH + listOfFile = glob.glob("Path/To/ImageToPredict/*.png") + modelPath = "Path/To/YourModel/model.ckpt" + + #GIVE YOUR CONFIGURATION USE TO GENERATE CURRENT model.ckpt + Layers = 3 + Features = 16 + Channels = 3 + regCoef = 0.01 + nClass = 2 + filterSize = 7 + + print('\n') + print('Path to model : ' + modelPath) + print('Layers = ' + Layers) + print('Features = ' + Features) + print('Channels = ' + Channels) + print('Regulation Coef = ' + regCoef) + print('Class number = ' + nClass) + print('Convolution size filter = ' + filterSize) + + net = unet.Unet(layers=int(Layers), features_root=int(Features), channels=int(Channels), regularisationConstant=float(regCoef), n_class=int(nClass), filter_size=int(filterSize)) + + for i in range(0, len(listOfFile)): + imagePath = listOfFile[i] + data_provider = image_util.SingleImageDataProvider(imagePath) + predicter = net.predict(modelPath, data_provider.img) + util.PlotSingleImagePrediction(data_provider, predicter, output_path=imagePath) + + print("\nPrediction finished !") \ No newline at end of file diff --git a/tf_unet/image_util.py b/tf_unet/image_util.py index 8c91fd4..fc1e8d2 100644 --- a/tf_unet/image_util.py +++ b/tf_unet/image_util.py @@ -158,4 +158,24 @@ def _next_data(self): img = self._load_file(image_name, np.float32) label = self._load_file(label_name, np.bool) - return img,label + return img, label + + +class SingleImageDataProvider(ImageDataProvider): + + def __init__(self, singleImagePath): + + super(ImageDataProvider, self).__init__(None, None) + + self.data_files = list() + self.data_files.append(singleImagePath) + self.data_suffix = '.' + str(singleImagePath.split('/')[-1].split('.')[-1]) + self.imgShape = self._load_file(self.data_files[0]).shape + self.channels = self.imgShape[-1] + self.file_idx = -1 + self.img = np.zeros((1, self.imgShape[0], self.imgShape[1], self.channels)) + + data = self._load_file(self.data_files[0], np.float32) + data = self._process_data(data) + + self.img[0] = data \ No newline at end of file diff --git a/tf_unet/layers.py b/tf_unet/layers.py index f321511..46cd4a7 100644 --- a/tf_unet/layers.py +++ b/tf_unet/layers.py @@ -22,50 +22,67 @@ import tensorflow as tf def weight_variable(shape, stddev=0.1): - initial = tf.truncated_normal(shape, stddev=stddev) - return tf.Variable(initial) + with tf.name_scope('weight_variable') as scope: + initial = tf.truncated_normal(shape, stddev=stddev) + return tf.Variable(initial) def weight_variable_devonc(shape, stddev=0.1): - return tf.Variable(tf.truncated_normal(shape, stddev=stddev)) + with tf.name_scope('weight_variable_devonc') as scope: + return tf.Variable(tf.truncated_normal(shape, stddev=stddev)) def bias_variable(shape): - initial = tf.constant(0.1, shape=shape) - return tf.Variable(initial) - -def conv2d(x, W,keep_prob_): - conv_2d = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID') - return tf.nn.dropout(conv_2d, keep_prob_) - -def deconv2d(x, W,stride): - x_shape = tf.shape(x) - output_shape = tf.stack([x_shape[0], x_shape[1]*2, x_shape[2]*2, x_shape[3]//2]) - return tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, stride, stride, 1], padding='VALID') - -def max_pool(x,n): - return tf.nn.max_pool(x, ksize=[1, n, n, 1], strides=[1, n, n, 1], padding='VALID') - -def crop_and_concat(x1,x2): - x1_shape = tf.shape(x1) - x2_shape = tf.shape(x2) - # offsets for the top left corner of the crop - offsets = [0, (x1_shape[1] - x2_shape[1]) // 2, (x1_shape[2] - x2_shape[2]) // 2, 0] - size = [-1, x2_shape[1], x2_shape[2], -1] - x1_crop = tf.slice(x1, offsets, size) - return tf.concat([x1_crop, x2], 3) + with tf.name_scope('bias_variable') as scope: + initial = tf.constant(0.1, shape=shape) + return tf.Variable(initial) + + +def conv2d(x, W, keep_prob_): + with tf.name_scope('convolution') as scope: + # VALID = without padding + # SAME = padding with 0 + conv_2d = tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='VALID', name="LAST_CONVOLUTION") + return tf.nn.dropout(conv_2d, keep_prob_) + + +def deconv2d(x, W, stride): + with tf.name_scope('deconvolution') as scope: + x_shape = tf.shape(x) + output_shape = tf.stack([x_shape[0], x_shape[1] * 2, x_shape[2] * 2, x_shape[3] // 2]) + #output_shape = tf.pack([x_shape[0], x_shape[1]*2, x_shape[2]*2, x_shape[3]//2]) + return tf.nn.conv2d_transpose(x, W, output_shape, strides=[1, stride, stride, 1], padding='VALID') + + +def max_pool(x, n): + with tf.name_scope('max_pooling') as scope: + return tf.nn.max_pool(x, ksize=[1, n, n, 1], strides=[1, n, n, 1], padding='VALID') + + +def crop_and_concat(x1, x2): + with tf.name_scope('crop_and_concatenate') as scope: + x1_shape = tf.shape(x1) + x2_shape = tf.shape(x2) + # offsets for the top left corner of the crop + offsets = [0, (x1_shape[1] - x2_shape[1]) // 2, (x1_shape[2] - x2_shape[2]) // 2, 0] + size = [-1, x2_shape[1], x2_shape[2], -1] + x1_crop = tf.slice(x1, offsets, size) + return tf.concat(3, [x1_crop, x2]) def pixel_wise_softmax(output_map): - exponential_map = tf.exp(output_map) - evidence = tf.add(exponential_map,tf.reverse(exponential_map,[False,False,False,True])) - return tf.div(exponential_map,evidence, name="pixel_wise_softmax") + with tf.name_scope('pixel_wise_softmax') as scope: + exponential_map = tf.exp(output_map) + evidence = tf.add(exponential_map, tf.reverse(exponential_map,[False,False,False,True])) + return tf.div(exponential_map, evidence, name="pixel_wise_softmax") def pixel_wise_softmax_2(output_map): - exponential_map = tf.exp(output_map) - sum_exp = tf.reduce_sum(exponential_map, 3, keep_dims=True) - tensor_sum_exp = tf.tile(sum_exp, tf.stack([1, 1, 1, tf.shape(output_map)[3]])) - return tf.div(exponential_map,tensor_sum_exp) + with tf.name_scope('pixel_wise_softmax_2') as scope: + exponential_map = tf.exp(output_map) + sum_exp = tf.reduce_sum(exponential_map, 3, keep_dims=True) + tensor_sum_exp = tf.tile(sum_exp, tf.pack([1, 1, 1, tf.shape(output_map)[3]])) + return tf.div(exponential_map,tensor_sum_exp) def cross_entropy(y_,output_map): - return -tf.reduce_mean(y_*tf.log(tf.clip_by_value(output_map,1e-10,1.0)), name="cross_entropy") -# return tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(output_map), reduction_indices=[1])) + with tf.name_scope('cross_entropy') as scope: + return -tf.reduce_mean(y_*tf.log(tf.clip_by_value(output_map,1e-10,1.0)), name="cross_entropy") + # return tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(output_map), reduction_indices=[1])) diff --git a/tf_unet/unet.py b/tf_unet/unet.py index 2b3b24f..f92c093 100644 --- a/tf_unet/unet.py +++ b/tf_unet/unet.py @@ -30,7 +30,7 @@ from tf_unet import util from tf_unet.layers import (weight_variable, weight_variable_devonc, bias_variable, conv2d, deconv2d, max_pool, crop_and_concat, pixel_wise_softmax_2, - cross_entropy) + pixel_wise_softmax,cross_entropy) logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s') @@ -70,72 +70,97 @@ def create_conv_net(x, keep_prob, channels, n_class, layers=3, features_root=16, in_size = 1000 size = in_size - # down layers - for layer in range(0, layers): - features = 2**layer*features_root - stddev = np.sqrt(2 / (filter_size**2 * features)) - if layer == 0: - w1 = weight_variable([filter_size, filter_size, channels, features], stddev) - else: - w1 = weight_variable([filter_size, filter_size, features//2, features], stddev) - - w2 = weight_variable([filter_size, filter_size, features, features], stddev) - b1 = bias_variable([features]) - b2 = bias_variable([features]) - - conv1 = conv2d(in_node, w1, keep_prob) - tmp_h_conv = tf.nn.relu(conv1 + b1) - conv2 = conv2d(tmp_h_conv, w2, keep_prob) - dw_h_convs[layer] = tf.nn.relu(conv2 + b2) - - weights.append((w1, w2)) - biases.append((b1, b2)) - convs.append((conv1, conv2)) - - size -= 4 - if layer < layers-1: - pools[layer] = max_pool(dw_h_convs[layer], pool_size) - in_node = pools[layer] - size /= 2 - - in_node = dw_h_convs[layers-1] - - # up layers - for layer in range(layers-2, -1, -1): - features = 2**(layer+1)*features_root - stddev = np.sqrt(2 / (filter_size**2 * features)) - - wd = weight_variable_devonc([pool_size, pool_size, features//2, features], stddev) - bd = bias_variable([features//2]) - h_deconv = tf.nn.relu(deconv2d(in_node, wd, pool_size) + bd) - h_deconv_concat = crop_and_concat(dw_h_convs[layer], h_deconv) - deconv[layer] = h_deconv_concat - - w1 = weight_variable([filter_size, filter_size, features, features//2], stddev) - w2 = weight_variable([filter_size, filter_size, features//2, features//2], stddev) - b1 = bias_variable([features//2]) - b2 = bias_variable([features//2]) - - conv1 = conv2d(h_deconv_concat, w1, keep_prob) - h_conv = tf.nn.relu(conv1 + b1) - conv2 = conv2d(h_conv, w2, keep_prob) - in_node = tf.nn.relu(conv2 + b2) - up_h_convs[layer] = in_node - - weights.append((w1, w2)) - biases.append((b1, b2)) - convs.append((conv1, conv2)) - - size *= 2 - size -= 4 - - # Output Map - weight = weight_variable([1, 1, features_root, n_class], stddev) - bias = bias_variable([n_class]) - conv = conv2d(in_node, weight, tf.constant(1.0)) - output_map = tf.nn.relu(conv + bias) - up_h_convs["out"] = output_map - + + # find the number of pixels to suppress in one edge according to size of convolution filter + pixelEdge = (filter_size-1)/2 + # compute the total number of pixel to remove at each time we use convolution filter + step = pixelEdge * 2 + # Number of convolution in one layers (it is not a parameters) + nLayerofConvolution = 2 + # Variable use to count the number of convolution and max-pooling into architecture + nConvFilter = 0 + nMaxPooling = 0 + with tf.name_scope('DOWN_LAYER') as scope: + # down layers + for layer in range(0, layers): + features = 2 ** layer * features_root + stddev = np.sqrt(2 / (filter_size ** 2 * features)) + if layer == 0: + w1 = weight_variable([filter_size, filter_size, channels, features], stddev) + else: + w1 = weight_variable([filter_size, filter_size, features // 2, features], stddev) + + w2 = weight_variable([filter_size, filter_size, features, features], stddev) + b1 = bias_variable([features]) + b2 = bias_variable([features]) + + conv1 = conv2d(in_node, w1, keep_prob) + tmp_h_conv = tf.nn.relu(conv1 + b1) + conv2 = conv2d(tmp_h_conv, w2, keep_prob) + dw_h_convs[layer] = tf.nn.relu(conv2 + b2) + + nConvFilter += 2 + + weights.append((w1, w2)) + biases.append((b1, b2)) + convs.append((conv1, conv2)) + + size = math.floor(size - (step * nLayerofConvolution)) + + if layer < layers - 1: + pools[layer] = max_pool(dw_h_convs[layer], pool_size) + in_node = pools[layer] + size = math.floor(size / pool_size) + nMaxPooling += 1 + + in_node = dw_h_convs[layers - 1] + + with tf.name_scope('UP_LAYERS') as scope: + # up layers + for layer in range(layers - 2, -1, -1): + features = 2 ** (layer + 1) * features_root + stddev = np.sqrt(2 / (filter_size ** 2 * features)) + + wd = weight_variable_devonc([pool_size, pool_size, features // 2, features], stddev) + bd = bias_variable([features // 2]) + h_deconv = tf.nn.relu(deconv2d(in_node, wd, pool_size) + bd) #up-conv 2x2 + h_deconv_concat = crop_and_concat(dw_h_convs[layer], h_deconv) #copy and crop + deconv[layer] = h_deconv_concat + + size = math.floor(size * pool_size) + + nMaxPooling += 1 + nConvFilter += 1 + + w1 = weight_variable([filter_size, filter_size, features, features // 2], stddev) + w2 = weight_variable([filter_size, filter_size, features // 2, features // 2], stddev) + b1 = bias_variable([features // 2]) + b2 = bias_variable([features // 2]) + + conv1 = conv2d(h_deconv_concat, w1, keep_prob) #conv + h_conv = tf.nn.relu(conv1 + b1) #relu + conv2 = conv2d(h_conv, w2, keep_prob) #conv + in_node = tf.nn.relu(conv2 + b2) #relu + up_h_convs[layer] = in_node + + nConvFilter += 2 + + weights.append((w1, w2)) + biases.append((b1, b2)) + convs.append((conv1, conv2)) + + size = math.floor(size - (step * nLayerofConvolution)) + + with tf.name_scope('OUTPUT') as scope: + # Output Map + weight = weight_variable([1, 1, features_root, n_class], stddev) + bias = bias_variable([n_class]) + conv = conv2d(in_node, weight, tf.constant(1.0)) #Last convolution 1x1 + output_map = tf.nn.relu(conv + bias) + up_h_convs["out"] = output_map + nMaxPooling += 1 + nConvFilter += 1 + if summaries: for i, (c1, c2) in enumerate(convs): tf.summary.image('summary_conv_%02d_01'%i, get_image_summary(c1)) @@ -162,7 +187,12 @@ def create_conv_net(x, keep_prob, channels, n_class, layers=3, features_root=16, variables.append(b1) variables.append(b2) - + print("\n\tpixelEdge = " + str(pixelEdge)) + print("\tOffset pixels with input image = " + str(in_size - size) + " pixels") + print("\tNumber of pixel to suppress : " + str((in_size - size)/4) + " by edge") + print('\tTotal Number of Convolution Filters : ' + str(nConvFilter)) + print('\tNumber of max-pooling & deconvolution : ' + str(nMaxPooling)) + return output_map, variables, int(in_size - size) @@ -182,19 +212,21 @@ def __init__(self, channels=3, n_class=2, cost="cross_entropy", cost_kwargs={}, self.n_class = n_class self.summaries = kwargs.get("summaries", True) + with tf.name_scope('inputs') as scope: self.x = tf.placeholder("float", shape=[None, None, None, channels]) self.y = tf.placeholder("float", shape=[None, None, None, n_class]) self.keep_prob = tf.placeholder(tf.float32) #dropout (keep probability) + with tf.name_scope('CONVNET') as scope: logits, self.variables, self.offset = create_conv_net(self.x, self.keep_prob, channels, n_class, **kwargs) - + with tf.name_scope('COST') as scope: self.cost = self._get_cost(logits, cost, cost_kwargs) - + with tf.name_scope('GRADIENT') as scope: self.gradients_node = tf.gradients(self.cost, self.variables) - + with tf.name_scope('CROSSENTROPY') as scope: self.cross_entropy = tf.reduce_mean(cross_entropy(tf.reshape(self.y, [-1, n_class]), tf.reshape(pixel_wise_softmax_2(logits), [-1, n_class]))) - + with tf.name_scope('PREDICTER') as scope: self.predicter = pixel_wise_softmax_2(logits) self.correct_pred = tf.equal(tf.argmax(self.predicter, 3), tf.argmax(self.y, 3)) self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32)) @@ -310,9 +342,9 @@ def __init__(self, net, batch_size=1, optimizer="momentum", opt_kwargs={}): def _get_optimizer(self, training_iters, global_step): if self.optimizer == "momentum": - learning_rate = self.opt_kwargs.pop("learning_rate", 0.2) - decay_rate = self.opt_kwargs.pop("decay_rate", 0.95) - momentum = self.opt_kwargs.pop("momentum", 0.2) + learning_rate = self.opt_kwargs.pop('learning_rate') + decay_rate = self.opt_kwargs.pop('decay_rate') + momentum = self.opt_kwargs.pop('momentum') self.learning_rate_node = tf.train.exponential_decay(learning_rate=learning_rate, global_step=global_step, @@ -423,8 +455,9 @@ def train(self, data_provider, output_path, training_iters=10, epochs=100, dropo self.norm_gradients_node.assign(norm_gradients).eval() if step % display_step == 0: - self.output_minibatch_stats(sess, summary_writer, step, batch_x, util.crop_to_shape(batch_y, pred_shape)) - + self.output_minibatch_stats(sess, summary_writer, step, batch_x, + util.crop_to_shape(batch_y, pred_shape)) + total_loss += loss self.output_epoch_stats(epoch, total_loss, training_iters, lr) From ae9593ac733acbe3a7945c1b467110223ac2bf31 Mon Sep 17 00:00:00 2001 From: Jordan G Date: Tue, 23 May 2017 10:21:33 +0200 Subject: [PATCH 2/3] ADD : function to save single images (not overload memory GPU). ENH : change format of output images to tif. (because prediction is in float) ADD : two new scripts for generate and predict model ENH : save model and image from training in single directory. --- GenerateUNetModel.py | 18 +++++------ PredictDirectoryUNet.py | 67 +++++++++++++++++++++-------------------- tf_unet/unet.py | 58 ++++++++++++++++++----------------- tf_unet/util.py | 14 ++++++++- 4 files changed, 86 insertions(+), 71 deletions(-) diff --git a/GenerateUNetModel.py b/GenerateUNetModel.py index cfb7cf4..3d98b3d 100644 --- a/GenerateUNetModel.py +++ b/GenerateUNetModel.py @@ -20,30 +20,28 @@ def devTest(): Parameters['depthConvolutionFilter'] = 16 Parameters['channel'] = 3 Parameters['outputClassNumber'] = 2 - Parameters['regulationCoefficient'] = 0.01 Parameters['optimizerType'] = "momentum" Parameters['optimizerValue'] = 0.90 Parameters['learningRate'] = 0.01 Parameters['batchSize'] = 2 Parameters['decayRate'] = 0.95 Parameters['dropout'] = 0.8 - Parameters['epoch'] = 100 - Parameters['iterations'] = 10 - Parameters['datasrc'] = "C:/Work/Projets/Provital - Vergeture/Images apprentissage/test2" - Parameters['output'] = "C:/Work/Projets/Provital - Vergeture/Images apprentissage/Model" + Parameters['epoch'] = 1 + Parameters['iterations'] = 0 + Parameters['datasrc'] = "C:/Work/Projets/Git - UNets/Data" + Parameters['output'] = "C:/Work/Projets/Git - UNets/Model" return Parameters if __name__ == "__main__": Parameters = devTest() -# !!! Remove current Parameters[output] directory to create a new one !!! +# !!! Remove current path in Parameters[output] directory to create a new one !!! if (os.path.isdir(Parameters['output'] + '/model')): shutil.rmtree(Parameters['output']) os.mkdir(Parameters['output']) os.mkdir(Parameters['output'] + '/model') os.mkdir(Parameters['output'] + '/trainPrediction') -os.mkdir(Parameters['output'] + '/testPrediction') #Load images to train UNet data_provider = image_util.ImageDataProvider(Parameters['datasrc'] + '/*', data_suffix=".png", mask_suffix="_mask.png") @@ -60,7 +58,6 @@ def devTest(): net = unet.Unet(layers=int(Parameters['layers']), features_root=int(Parameters['depthConvolutionFilter']), channels=int(Parameters['channel']), - regularisationConstant=np.float32(Parameters['regulationCoefficient']), n_class=int(Parameters['outputClassNumber']), filter_size=int(Parameters['convolutionFilter'])) @@ -72,9 +69,8 @@ def devTest(): learning_rate=np.float32(Parameters['learningRate']), decay_rate=np.float32(Parameters['decayRate']))) print("\tUNet initialized") -path = trainer.train(data_provider, - str(Parameters['output']), - str(Parameters['output'] + '/trainPrediction'), +path = trainer.train(data_provider=data_provider, + output_path=str(Parameters['output']), dropout=np.float32(Parameters['dropout']), training_iters=int(Parameters['iterations']), epochs=int(Parameters['epoch'])) diff --git a/PredictDirectoryUNet.py b/PredictDirectoryUNet.py index 0ed864f..0c25944 100644 --- a/PredictDirectoryUNet.py +++ b/PredictDirectoryUNet.py @@ -3,35 +3,38 @@ import glob import os -if __name__ == "__main__": - - #CHANGE YO YOUR PATH - listOfFile = glob.glob("Path/To/ImageToPredict/*.png") - modelPath = "Path/To/YourModel/model.ckpt" - - #GIVE YOUR CONFIGURATION USE TO GENERATE CURRENT model.ckpt - Layers = 3 - Features = 16 - Channels = 3 - regCoef = 0.01 - nClass = 2 - filterSize = 7 - - print('\n') - print('Path to model : ' + modelPath) - print('Layers = ' + Layers) - print('Features = ' + Features) - print('Channels = ' + Channels) - print('Regulation Coef = ' + regCoef) - print('Class number = ' + nClass) - print('Convolution size filter = ' + filterSize) - - net = unet.Unet(layers=int(Layers), features_root=int(Features), channels=int(Channels), regularisationConstant=float(regCoef), n_class=int(nClass), filter_size=int(filterSize)) - - for i in range(0, len(listOfFile)): - imagePath = listOfFile[i] - data_provider = image_util.SingleImageDataProvider(imagePath) - predicter = net.predict(modelPath, data_provider.img) - util.PlotSingleImagePrediction(data_provider, predicter, output_path=imagePath) - - print("\nPrediction finished !") \ No newline at end of file +#Script which predict a flow of images. Can be used with GPU Tensorflow without overload graphic memory + + +#CHANGE YO YOUR PATH +listOfFile = glob.glob("C:/Work/Projets/Git - UNets/DataTest/*.png") +modelPath = "C:/Work/Projets/Git - UNets/Model/model/model.ckpt" + +#GIVE YOUR CONFIGURATION USE TO GENERATE CURRENT model.ckpt +Layers = 3 +Features = 16 +Channels = 3 +nClass = 2 +filterSize = 7 + +print('\n') +print('Path to model : ' + modelPath) +print('Layers = ' + str(Layers)) +print('Features = ' + str(Features)) +print('Channels = ' + str(Channels)) +print('Class number = ' + str(nClass)) +print('Convolution size filter = ' + str(filterSize)) + +net = unet.Unet(layers=int(Layers), + features_root=int(Features), + channels=int(Channels), + n_class=int(nClass), + filter_size=int(filterSize)) + +for i in range(0, len(listOfFile)): + imagePath = listOfFile[i] + data_provider = image_util.SingleImageDataProvider(imagePath) + predicter = net.predict(modelPath, data_provider.img) + util.PlotSingleImagePrediction(predicter, output_path=imagePath) + +print("\nPrediction finished !") \ No newline at end of file diff --git a/tf_unet/unet.py b/tf_unet/unet.py index f1ced56..9222047 100644 --- a/tf_unet/unet.py +++ b/tf_unet/unet.py @@ -24,6 +24,7 @@ import numpy as np from collections import OrderedDict import logging +import math import tensorflow as tf @@ -212,24 +213,24 @@ def __init__(self, channels=3, n_class=2, cost="cross_entropy", cost_kwargs={}, self.n_class = n_class self.summaries = kwargs.get("summaries", True) - with tf.name_scope('inputs') as scope: - self.x = tf.placeholder("float", shape=[None, None, None, channels]) - self.y = tf.placeholder("float", shape=[None, None, None, n_class]) - self.keep_prob = tf.placeholder(tf.float32) #dropout (keep probability) - - with tf.name_scope('CONVNET') as scope: - logits, self.variables, self.offset = create_conv_net(self.x, self.keep_prob, channels, n_class, **kwargs) - with tf.name_scope('COST') as scope: - self.cost = self._get_cost(logits, cost, cost_kwargs) - with tf.name_scope('GRADIENT') as scope: - self.gradients_node = tf.gradients(self.cost, self.variables) - with tf.name_scope('CROSSENTROPY') as scope: - self.cross_entropy = tf.reduce_mean(cross_entropy(tf.reshape(self.y, [-1, n_class]), - tf.reshape(pixel_wise_softmax_2(logits), [-1, n_class]))) - with tf.name_scope('PREDICTER') as scope: - self.predicter = pixel_wise_softmax_2(logits) - self.correct_pred = tf.equal(tf.argmax(self.predicter, 3), tf.argmax(self.y, 3)) - self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32)) + with tf.name_scope('inputs') as scope: + self.x = tf.placeholder("float", shape=[None, None, None, channels]) + self.y = tf.placeholder("float", shape=[None, None, None, n_class]) + self.keep_prob = tf.placeholder(tf.float32) #dropout (keep probability) + + with tf.name_scope('CONVNET') as scope: + logits, self.variables, self.offset = create_conv_net(self.x, self.keep_prob, channels, n_class, **kwargs) + with tf.name_scope('COST') as scope: + self.cost = self._get_cost(logits, cost, cost_kwargs) + with tf.name_scope('GRADIENT') as scope: + self.gradients_node = tf.gradients(self.cost, self.variables) + with tf.name_scope('CROSSENTROPY') as scope: + self.cross_entropy = tf.reduce_mean(cross_entropy(tf.reshape(self.y, [-1, n_class]), + tf.reshape(pixel_wise_softmax_2(logits), [-1, n_class]))) + with tf.name_scope('PREDICTER') as scope: + self.predicter = pixel_wise_softmax_2(logits) + self.correct_pred = tf.equal(tf.argmax(self.predicter, 3), tf.argmax(self.y, 3)) + self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32)) def _get_cost(self, logits, cost_name, cost_kwargs): """ @@ -407,7 +408,7 @@ def train(self, data_provider, output_path, training_iters=10, epochs=100, dropo Lauches the training process :param data_provider: callable returning training and verification data - :param output_path: path where to store checkpoints + :param output_path: path where to store model and trainPrediction :param training_iters: number of training mini batch iteration :param epochs: number of epochs :param dropout: dropout probability @@ -415,27 +416,30 @@ def train(self, data_provider, output_path, training_iters=10, epochs=100, dropo :param restore: Flag if previous model should be restored :param write_graph: Flag if the computation graph should be written as protobuf file to the output path """ - save_path = os.path.join(output_path, "model.cpkt") + + outputModel = output_path + '/model' + save_path = os.path.join(outputModel, "model.ckpt") + self.prediction_path = output_path + '/trainPrediction' if epochs == 0: return save_path - init = self._initialize(training_iters, output_path, restore) + init = self._initialize(training_iters, outputModel, restore) with tf.Session() as sess: if write_graph: - tf.train.write_graph(sess.graph_def, output_path, "graph.pb", False) + tf.train.write_graph(sess.graph_def, outputModel, "graph.pb", False) sess.run(init) if restore: - ckpt = tf.train.get_checkpoint_state(output_path) + ckpt = tf.train.get_checkpoint_state(outputModel) if ckpt and ckpt.model_checkpoint_path: self.net.restore(sess, ckpt.model_checkpoint_path) test_x, test_y = data_provider(self.verification_batch_size) pred_shape = self.store_prediction(sess, test_x, test_y, "_init") - summary_writer = tf.summary.FileWriter(output_path, graph=sess.graph) + summary_writer = tf.summary.FileWriter(outputModel, graph=sess.graph) logging.info("Start optimization") avg_gradients = None @@ -479,8 +483,8 @@ def store_prediction(self, sess, batch_x, batch_y, name): pred_shape = prediction.shape loss = sess.run(self.net.cost, feed_dict={self.net.x: batch_x, - self.net.y: util.crop_to_shape(batch_y, pred_shape), - self.net.keep_prob: 1.}) + self.net.y: util.crop_to_shape(batch_y, pred_shape), + self.net.keep_prob: 1.}) logging.info("Verification error= {:.1f}%, loss= {:.4f}".format(error_rate(prediction, util.crop_to_shape(batch_y, @@ -488,7 +492,7 @@ def store_prediction(self, sess, batch_x, batch_y, name): loss)) img = util.combine_img_prediction(batch_x, batch_y, prediction) - util.save_image(img, "%s/%s.jpg"%(self.prediction_path, name)) + util.save_image(img, "%s/%s.tif"%(self.prediction_path, name)) return pred_shape diff --git a/tf_unet/util.py b/tf_unet/util.py index a552dcc..5702617 100644 --- a/tf_unet/util.py +++ b/tf_unet/util.py @@ -54,6 +54,18 @@ def plot_prediction(x_test, y_test, prediction, save=False): fig.show() plt.show() + +def PlotSingleImagePrediction(prediction, output_path): + from PIL import Image + + pred = prediction[0, ..., 1] + pred -= np.amin(pred) + pred /= np.amax(pred) + + filename = output_path.split('.')[0] + '.tif' + img = Image.fromarray(pred) + img.save(filename) + def to_rgb(img): """ Converts the given array into a RGB image. If the number of channels is not @@ -110,5 +122,5 @@ def save_image(img, path): :param img: the rgb image to save :param path: the target path """ - Image.fromarray(img.round().astype(np.uint8)).save(path, 'JPEG', dpi=[300,300], quality=90) + Image.fromarray(img.round().astype(np.uint8)).save(path, 'PNG', dpi=[300,300], quality=100) From d874b8a1e72b69c9c309b13d9c34f835d48975c8 Mon Sep 17 00:00:00 2001 From: Jordan G Date: Tue, 23 May 2017 10:26:11 +0200 Subject: [PATCH 3/3] ENH : gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 636b2fd..52abb14 100644 --- a/.gitignore +++ b/.gitignore @@ -54,3 +54,5 @@ docs/_build /.hope/ /.ipynb_checkpoints/ /.settings/ +*.iml +*.xml