From bd96bf25a9e3f8e2d1da2e1f8d0d21a9596421f8 Mon Sep 17 00:00:00 2001 From: scheng123 Date: Mon, 29 Apr 2019 16:39:48 -0400 Subject: [PATCH] Added new features Implemented the function of saving SavedModel, checkpoint, and CTRL+C to save --- retrain.py | 130 ++++++++++++++++++++++++++++++++--------------------- 1 file changed, 80 insertions(+), 50 deletions(-) diff --git a/retrain.py b/retrain.py index 286796a..b8daab0 100644 --- a/retrain.py +++ b/retrain.py @@ -1,3 +1,4 @@ +import shutil # Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -866,56 +867,59 @@ def main(_): sess.run(init) # Run the training for as many cycles as requested on the command line. - for i in range(FLAGS.how_many_training_steps): - # Get a batch of input bottleneck values, either calculated fresh every - # time with distortions applied, or from the cache stored on disk. - if do_distort_images: - (train_bottlenecks, - train_ground_truth) = get_random_distorted_bottlenecks( - sess, image_lists, FLAGS.train_batch_size, 'training', - FLAGS.image_dir, distorted_jpeg_data_tensor, - distorted_image_tensor, resized_image_tensor, bottleneck_tensor) - else: - (train_bottlenecks, - train_ground_truth, _) = get_random_cached_bottlenecks( - sess, image_lists, FLAGS.train_batch_size, 'training', - FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, - bottleneck_tensor) - # Feed the bottlenecks and ground truth into the graph, and run a training - # step. Capture training summaries for TensorBoard with the `merged` op. - - train_summary, _ = sess.run( - [merged, train_step], - feed_dict={bottleneck_input: train_bottlenecks, - ground_truth_input: train_ground_truth}) - train_writer.add_summary(train_summary, i) - - # Every so often, print out how well the graph is training. - is_last_step = (i + 1 == FLAGS.how_many_training_steps) - if (i % FLAGS.eval_step_interval) == 0 or is_last_step: - train_accuracy, cross_entropy_value = sess.run( - [evaluation_step, cross_entropy], + try: + for i in range(FLAGS.how_many_training_steps): + # Get a batch of input bottleneck values, either calculated fresh every + # time with distortions applied, or from the cache stored on disk. + if do_distort_images: + (train_bottlenecks, + train_ground_truth) = get_random_distorted_bottlenecks( + sess, image_lists, FLAGS.train_batch_size, 'training', + FLAGS.image_dir, distorted_jpeg_data_tensor, + distorted_image_tensor, resized_image_tensor, bottleneck_tensor) + else: + (train_bottlenecks, + train_ground_truth, _) = get_random_cached_bottlenecks( + sess, image_lists, FLAGS.train_batch_size, 'training', + FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, + bottleneck_tensor) + # Feed the bottlenecks and ground truth into the graph, and run a training + # step. Capture training summaries for TensorBoard with the `merged` op. + + train_summary, _ = sess.run( + [merged, train_step], feed_dict={bottleneck_input: train_bottlenecks, ground_truth_input: train_ground_truth}) - print('%s: Step %d: Train accuracy = %.1f%%' % (datetime.now(), i, - train_accuracy * 100)) - print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i, - cross_entropy_value)) - validation_bottlenecks, validation_ground_truth, _ = ( - get_random_cached_bottlenecks( - sess, image_lists, FLAGS.validation_batch_size, 'validation', - FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, - bottleneck_tensor)) - # Run a validation step and capture training summaries for TensorBoard - # with the `merged` op. - validation_summary, validation_accuracy = sess.run( - [merged, evaluation_step], - feed_dict={bottleneck_input: validation_bottlenecks, - ground_truth_input: validation_ground_truth}) - validation_writer.add_summary(validation_summary, i) - print('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' % - (datetime.now(), i, validation_accuracy * 100, - len(validation_bottlenecks))) + train_writer.add_summary(train_summary, i) + + # Every so often, print out how well the graph is training. + is_last_step = (i + 1 == FLAGS.how_many_training_steps) + if (i % FLAGS.eval_step_interval) == 0 or is_last_step: + train_accuracy, cross_entropy_value = sess.run( + [evaluation_step, cross_entropy], + feed_dict={bottleneck_input: train_bottlenecks, + ground_truth_input: train_ground_truth}) + print('%s: Step %d: Train accuracy = %.1f%%' % (datetime.now(), i, + train_accuracy * 100)) + print('%s: Step %d: Cross entropy = %f' % (datetime.now(), i, + cross_entropy_value)) + validation_bottlenecks, validation_ground_truth, _ = ( + get_random_cached_bottlenecks( + sess, image_lists, FLAGS.validation_batch_size, 'validation', + FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, + bottleneck_tensor)) + # Run a validation step and capture training summaries for TensorBoard + # with the `merged` op. + validation_summary, validation_accuracy = sess.run( + [merged, evaluation_step], + feed_dict={bottleneck_input: validation_bottlenecks, + ground_truth_input: validation_ground_truth}) + validation_writer.add_summary(validation_summary, i) + print('%s: Step %d: Validation accuracy = %.1f%% (N=%d)' % + (datetime.now(), i, validation_accuracy * 100, + len(validation_bottlenecks))) + except KeyboardInterrupt: + print('CTRL+C caught, saving...') # We've completed all our training, so run a final test evaluation on # some new images we haven't used before. @@ -937,7 +941,21 @@ def main(_): if predictions[i] != test_ground_truth[i].argmax(): print('%70s %s' % (test_filename, list(image_lists.keys())[predictions[i]])) - + + # Add the checkpoint saving function + saver = tf.train.Saver() + save_path = saver.save(sess, FLAGS.output_checkpoint) + print("Checkpoint saved to %s" % save_path) + + # Add the SavedModel file saving function + if os.path.exists(FLAGS.output_savedmodel) and os.path.isdir(FLAGS.output_savedmodel): + shutil.rmtree(FLAGS.output_savedmodel) # SavedModel saver will throw error if dir exists + tf.saved_model.simple_save(sess, + FLAGS.output_savedmodel, + inputs={"%s" % JPEG_DATA_TENSOR_NAME: jpeg_data_tensor}, + outputs={"%s" % FLAGS.final_tensor_name: final_tensor}) + print("SavedModel saved to "+FLAGS.output_savedmodel) + # Write out the trained graph and labels with the weights stored as # constants. output_graph_def = graph_util.convert_variables_to_constants( @@ -962,6 +980,18 @@ def main(_): default='/tmp/output_graph.pb', help='Where to save the trained graph.' ) + parser.add_argument( + '--output_checkpoint', + type=str, + default='/tmp/checkpoint.ckpt', + help='Where to save the checkpoint.' + ) + parser.add_argument( + '--output_savedmodel', + type=str, + default='/tmp/saved_model.pb', + help='Where to save the SavedModel.' + ) parser.add_argument( '--output_labels', type=str, @@ -1102,4 +1132,4 @@ def main(_): """ ) FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) \ No newline at end of file + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)