-
Notifications
You must be signed in to change notification settings - Fork 162
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[83] Example to run a deep learning demo #99
Open
thunterdb
wants to merge
9
commits into
databricks:master
Choose a base branch
from
thunterdb:83-example
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
4ceee64
support for strings
thunterdb baf6542
removing type tags, they do not work with scala 2.10
thunterdb ed4d39c
removing type tags, they do not work with scala 2.10
thunterdb 46ce8d0
removing type tags, they do not work with scala 2.10
thunterdb 839c945
changes
thunterdb 2d0b497
changes
thunterdb 8a560f1
Merge remote-tracking branch 'upstream/master' into 83-example
thunterdb 7ab76b8
small change
thunterdb ac2b1e1
Merge remote-tracking branch 'upstream/master' into 83-example
thunterdb File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
189 changes: 189 additions & 0 deletions
189
src/main/python/tensorframes_snippets/preparation_inceptionv3.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,189 @@ | ||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
from datetime import datetime | ||
import math | ||
import os.path | ||
import time | ||
|
||
|
||
import numpy as np | ||
import tensorflow as tf | ||
|
||
from preprocessing import inception_preprocessing | ||
import datasets.imagenet as imagenet | ||
from nets import inception | ||
import datasets.dataset_utils as dataset_utils | ||
|
||
import tensorflow as tf | ||
from tensorflow.python.training import saver as tf_saver | ||
from tensorflow.python.framework import graph_util | ||
|
||
slim = tf.contrib.slim | ||
|
||
default_image_size = 299 | ||
|
||
|
||
####### Download the network data | ||
# The URL of the checkpointed data. | ||
url = "http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz" | ||
# The name of the checkpoint file: | ||
checkpoint_file = 'inception_v3.ckpt' | ||
# Specify where you want to download the model to | ||
checkpoints_dir = '/tmp/checkpoints' | ||
|
||
checkpoint_path = os.path.join(checkpoints_dir, checkpoint_file) | ||
if not tf.gfile.Exists(checkpoints_dir): | ||
tf.gfile.MakeDirs(checkpoints_dir) | ||
|
||
if not tf.gfile.Exists(checkpoint_path): | ||
print('Downloading the model...') | ||
dataset_utils.download_and_uncompress_tarball(url, checkpoints_dir) | ||
|
||
#### TEST TO REMOVE | ||
|
||
s = tf.constant("This is string") | ||
r = tf.decode_raw(s, tf.int8) | ||
s2 = tf.as_string(r) | ||
sess = tf.InteractiveSession() | ||
print(s2.eval()) | ||
|
||
###### Building the computation graph | ||
|
||
# All this code can be run once. It assembles the computation graph, fills it with the checkpointed | ||
# coefficients, and then saves it as a protocol buffer description. | ||
|
||
# Build the graph | ||
g = tf.Graph() | ||
with g.as_default(): | ||
# Keep for now a placeholder that will eventually be filled with the content of the image. | ||
# This code only accepts JPEG images, which is the most common image format. | ||
image_string = tf.placeholder(tf.string, [], name="image_input") | ||
|
||
# Decode string into matrix with intensity values | ||
image = tf.image.decode_jpeg(image_string, channels=3) | ||
|
||
# Resize the input image, preserving the aspect ratio | ||
# and make a central crop of the resulted image. | ||
# The crop will be of the size of the default image size of | ||
# the network. | ||
processed_image = inception_preprocessing.preprocess_image(image, | ||
default_image_size, | ||
default_image_size, | ||
is_training=False) | ||
|
||
# Networks accept images in batches. | ||
# The first dimension usually represents the batch size. | ||
# In our case the batch size is one. | ||
processed_images = tf.expand_dims(processed_image, 0) | ||
|
||
# Create the model, use the default arg scope to configure | ||
# the batch norm parameters. arg_scope is a very conveniet | ||
# feature of slim library -- you can define default | ||
# parameters for layers -- like stride, padding etc. | ||
# Note: like the Arabian nights, inception defines 1001 classes | ||
# to include a background class (the first). | ||
with slim.arg_scope(inception.inception_v3_arg_scope()): | ||
logits, _ = inception.inception_v3(processed_images, | ||
num_classes=1001, | ||
is_training=False) | ||
|
||
# In order to get probabilities we apply softmax on the output. | ||
probabilities = tf.nn.softmax(logits) | ||
|
||
# Just focus on the top predictions | ||
top_pred = tf.nn.top_k(tf.squeeze(probabilities), k=5, name="top_predictions") | ||
|
||
# These are the outputs we will be requesting from the network. | ||
output_nodes = [probabilities, top_pred.indices, top_pred.values] | ||
|
||
# Create the saver | ||
with g.as_default(): | ||
model_variables = slim.get_model_variables('InceptionV3') | ||
saver = tf_saver.Saver(model_variables, reshape=False) | ||
|
||
def get_op_name(tensor): | ||
return tensor.name.split(":")[0] | ||
|
||
# Export the network | ||
with g.as_default(): | ||
with tf.Session() as sess: | ||
saver.restore(sess, checkpoint_path) | ||
# The add_shapes option is important: Spark requires this extra shape information to infor the | ||
# correct types. | ||
input_graph_def = g.as_graph_def(add_shapes=True) | ||
output_tensor_names = [node.name for node in output_nodes] | ||
output_node_names = [n.split(":")[0] for n in output_tensor_names] | ||
output_graph_def = graph_util.convert_variables_to_constants( | ||
sess, | ||
input_graph_def, | ||
output_node_names, | ||
variable_names_blacklist=[]) | ||
|
||
# The variable 'output_graph_def' now contains all the description of the computation. | ||
# The variables in the 'output_nodes' list will be used to know what to output. | ||
|
||
####### Testing the computation graph | ||
|
||
# This code performs a sanity check, by running the network against some image content downloaded from the internet. | ||
|
||
g2 = tf.Graph() | ||
with g2.as_default(): | ||
tf.import_graph_def(output_graph_def, name='') | ||
|
||
#### Download an image | ||
import requests | ||
|
||
# Example picture: | ||
# Specify where you want to download the model to | ||
images_dir = '/tmp/image_data' | ||
|
||
if not tf.gfile.Exists(images_dir): | ||
tf.gfile.MakeDirs(images_dir) | ||
|
||
image_url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/8/85/WeaverAntDefense.JPG/640px-WeaverAntDefense.JPG' | ||
image_url = 'https://www.tensorflow.org/images/cropped_panda.jpg' | ||
image_path = os.path.join(images_dir, image_url.split('/')[-1]) | ||
|
||
if not tf.gfile.Exists(image_path): | ||
response = requests.get(image_url) | ||
if response.status_code == 200: | ||
with open(image_path, 'wb') as f: | ||
f.write(response.content) | ||
|
||
image_data = tf.gfile.FastGFile(image_path, 'rb').read() | ||
|
||
with g2.as_default(): | ||
input_node2 = g2.get_operation_by_name(get_op_name(image)) | ||
output_nodes2 = [g2.get_tensor_by_name(n) for n in output_tensor_names] | ||
with tf.Session() as sess: | ||
(probabilities_, indices_, values_) = sess.run(output_nodes2, {'image_input:0':image_data}) | ||
|
||
names = imagenet.create_readable_names_for_imagenet_labels() | ||
for i in range(5): | ||
index = indices_[i] | ||
print('Probability %d %0.2f => [%s]' % (index, values_[i], names[index])) | ||
|
||
|
||
###### Perform some evaluation with TensorFrames | ||
|
||
# This code takes the network and a directory that contains some image content. It shows how to process the content | ||
# using Spark dataframes and Tensorframes. | ||
|
||
import tensorframes as tfs | ||
sc.setLogLevel('INFO') | ||
|
||
raw_images_miscast = sc.binaryFiles("file:"+images_dir) | ||
raw_images = raw_images_miscast.map(lambda x: (x[0], bytearray(x[1]))) | ||
|
||
df = spark.createDataFrame(raw_images).toDF('image_uri', 'image_data') | ||
df | ||
|
||
with g2.as_default(): | ||
index_output = tf.identity(g2.get_tensor_by_name('top_predictions:1'), name="index") | ||
value_output = tf.identity(g2.get_tensor_by_name('top_predictions:0'), name="value") | ||
pred_df = tfs.map_rows([index_output, value_output], df, feed_dict={'image_input':'image_data'}) | ||
|
||
pred_df.select('index', 'value').show() | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove?