Skip to content

Commit 3160dad

Browse files
vicaireLeegleechN
authored andcommitted
Add support for batch prediction. (google#37)
1 parent 568dada commit 3160dad

File tree

5 files changed

+363
-24
lines changed

5 files changed

+363
-24
lines changed

README.md

+61
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ or on your own machine. This README provides instructions for both.
1717
* [Testing Locally](#testing-locally)
1818
* [Training on the Cloud over Video-Level Features](#training-on-video-level-features)
1919
* [Evaluation and Inference](#evaluation-and-inference)
20+
* [Inference Using Batch Prediction](#inference-using-batch-prediction)
2021
* [Accessing Files on Google Cloud](#accessing-files-on-google-cloud)
2122
* [Using Frame-Level Features](#using-frame-level-features)
2223
* [Using Audio Features](#using-audio-features)
@@ -187,6 +188,62 @@ and the following for the inference code:
187188
num examples processed: 8192 elapsed seconds: 14.85
188189
```
189190

191+
### Inference Using Batch Prediction
192+
To perform inference faster, you can also use the Cloud ML batch prediction
193+
service.
194+
195+
First, find the directory where the training job exported the model:
196+
197+
```
198+
gsutil list ${BUCKET_NAME}/yt8m_train_video_level_logistic_model/export
199+
```
200+
201+
You should see an output similar to this one:
202+
203+
```
204+
${BUCKET_NAME}/yt8m_train_video_level_logistic_model/export/
205+
${BUCKET_NAME}/yt8m_train_video_level_logistic_model/export/step_1/
206+
${BUCKET_NAME}/yt8m_train_video_level_logistic_model/export/step_1001/
207+
${BUCKET_NAME}/yt8m_train_video_level_logistic_model/export/step_2001/
208+
${BUCKET_NAME}/yt8m_train_video_level_logistic_model/export/step_3001/
209+
```
210+
211+
Select the latest version of the model that was saved. For instance, in our
212+
case, we select the version of the model that was saved at step 3001:
213+
214+
```
215+
EXPORTED_MODEL_DIR=${BUCKET_NAME}/yt8m_train_video_level_logistic_model/export/step_3001/
216+
```
217+
218+
Start the batch prediction job using the following command:
219+
220+
```
221+
JOB_NAME=yt8m_batch_predict_$(date +%Y%m%d_%H%M%S); \
222+
gcloud beta ml jobs submit prediction ${JOB_NAME} --verbosity=debug \
223+
--model-dir=${EXPORTED_MODEL_DIR} --data-format=TF_RECORD \
224+
--input-paths=gs://youtube8m-ml/1/video_level/test/test* \
225+
--output-path=${BUCKET_NAME}/batch_predict/${JOB_NAME} --region=us-east1 \
226+
--runtime-version=1.0 --max-worker-count=10
227+
```
228+
229+
You can check the progress of the job on the
230+
[Google Cloud ML Jobs console](https://console.cloud.google.com/ml/jobs). To
231+
have the job complete faster, you can increase 'max-worker-count' to a
232+
higher value.
233+
234+
Once the batch prediction job has completed, turn its output into a submission
235+
in the CVS format by running the following commands:
236+
237+
```
238+
# Copy the output of the batch prediction job to a local directory
239+
mkdir -p /tmp/batch_predict/${JOB_NAME}
240+
gsutil -m cp -r ${BUCKET_NAME}/batch_predict/${JOB_NAME}/* /tmp/batch_predict/${JOB_NAME}/
241+
242+
# Convert the output of the batch prediction job into a CVS file ready for submission
243+
python youtube-8m/convert_prediction_from_json_to_csv.py \
244+
--json_prediction_files_pattern="/tmp/batch_predict/${JOB_NAME}/prediction.results-*" \
245+
--csv_output_file="/tmp/batch_predict/${JOB_NAME}/output.csv"
246+
```
190247

191248
### Accessing Files on Google Cloud
192249

@@ -428,6 +485,8 @@ This sample code contains implementations of the models given in the
428485
level features as input.
429486
* `model_util.py`: Contains functions that are of general utility for
430487
implementing models.
488+
* `export_model.py`: Provides a class to export a model during training
489+
for later use in batch prediction.
431490
* `readers.py`: Contains definitions for the Video dataset and Frame
432491
dataset readers.
433492

@@ -446,6 +505,8 @@ This sample code contains implementations of the models given in the
446505
### Misc
447506
* `README.md`: This documentation.
448507
* `utils.py`: Common functions.
508+
* `convert_prediction_from_json_to_csv.py`: Converts the JSON output of
509+
batch prediction into a CSV file for submission.
449510

450511
## About This Project
451512
This project is meant help people quickly get started working with the
+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2016 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS-IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utility to convert the output of batch prediction into a CSV submission.
16+
17+
It converts the JSON files created by the command
18+
'gcloud beta ml jobs submit prediction' into a CSV file ready for submission.
19+
"""
20+
21+
import json
22+
import tensorflow as tf
23+
24+
from builtins import range
25+
from tensorflow import app
26+
from tensorflow import flags
27+
from tensorflow import gfile
28+
from tensorflow import logging
29+
30+
31+
FLAGS = flags.FLAGS
32+
33+
if __name__ == '__main__':
34+
35+
flags.DEFINE_string(
36+
"json_prediction_files_pattern", None,
37+
"Pattern specifying the list of JSON files that the command "
38+
"'gcloud beta ml jobs submit prediction' outputs. These files are "
39+
"located in the output path of the prediction command and are prefixed "
40+
"with 'prediction.results'.")
41+
flags.DEFINE_string(
42+
"csv_output_file", None,
43+
"The file to save the predictions converted to the CSV format.")
44+
45+
46+
def get_csv_header():
47+
return "VideoId,LabelConfidencePairs\n"
48+
49+
def to_csv_row(json_data):
50+
51+
video_id = json_data["video_id"]
52+
53+
class_indexes = json_data["class_indexes"]
54+
predictions = json_data["predictions"]
55+
56+
if isinstance(video_id, list):
57+
video_id = video_id[0]
58+
class_indexes = class_indexes[0]
59+
predictions = predictions[0]
60+
61+
if len(class_indexes) != len(predictions):
62+
raise ValueError(
63+
"The number of indexes (%s) and predictions (%s) must be equal."
64+
% (len(class_indexes), len(predictions)))
65+
66+
return (video_id.decode('utf-8') + "," + " ".join("%i %f" %
67+
(class_indexes[i], predictions[i])
68+
for i in range(len(class_indexes))) + "\n")
69+
70+
def main(unused_argv):
71+
logging.set_verbosity(tf.logging.INFO)
72+
73+
if not FLAGS.json_prediction_files_pattern:
74+
raise ValueError(
75+
"The flag --json_prediction_files_pattern must be specified.")
76+
77+
if not FLAGS.csv_output_file:
78+
raise ValueError("The flag --csv_output_file must be specified.")
79+
80+
logging.info("Looking for prediction files with pattern: %s",
81+
FLAGS.json_prediction_files_pattern)
82+
83+
file_paths = gfile.Glob(FLAGS.json_prediction_files_pattern)
84+
logging.info("Found files: %s", file_paths)
85+
86+
logging.info("Writing submission file to: %s", FLAGS.csv_output_file)
87+
with gfile.Open(FLAGS.csv_output_file, "w+") as output_file:
88+
output_file.write(get_csv_header())
89+
90+
for file_path in file_paths:
91+
logging.info("processing file: %s", file_path)
92+
93+
with gfile.Open(file_path) as input_file:
94+
95+
for line in input_file:
96+
json_data = json.loads(line)
97+
output_file.write(to_csv_row(json_data))
98+
99+
output_file.flush()
100+
logging.info("done")
101+
102+
if __name__ == "__main__":
103+
app.run()

export_model.py

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2016 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS-IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Utilities to export a model for batch prediction."""
15+
16+
import tensorflow as tf
17+
import tensorflow.contrib.slim as slim
18+
19+
from tensorflow.python.saved_model import builder as saved_model_builder
20+
from tensorflow.python.saved_model import signature_constants
21+
from tensorflow.python.saved_model import signature_def_utils
22+
from tensorflow.python.saved_model import tag_constants
23+
from tensorflow.python.saved_model import utils as saved_model_utils
24+
25+
_TOP_PREDICTIONS_IN_OUTPUT = 20
26+
27+
class ModelExporter(object):
28+
29+
def __init__(self, frame_features, model, reader):
30+
self.frame_features = frame_features
31+
self.model = model
32+
self.reader = reader
33+
34+
with tf.Graph().as_default() as graph:
35+
self.inputs, self.outputs = self.build_inputs_and_outputs()
36+
self.graph = graph
37+
self.saver = tf.train.Saver(tf.trainable_variables(), sharded=True)
38+
39+
def export_model(self, model_dir, global_step_val, last_checkpoint):
40+
"""Exports the model so that it can used for batch predictions."""
41+
42+
with self.graph.as_default():
43+
with tf.Session() as session:
44+
self.saver.restore(session, last_checkpoint)
45+
46+
signature = signature_def_utils.build_signature_def(
47+
inputs=self.inputs,
48+
outputs=self.outputs,
49+
method_name=signature_constants.PREDICT_METHOD_NAME)
50+
51+
signature_map = {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
52+
signature}
53+
54+
model_builder = saved_model_builder.SavedModelBuilder(model_dir)
55+
model_builder.add_meta_graph_and_variables(session,
56+
tags=[tag_constants.SERVING],
57+
signature_def_map=signature_map,
58+
clear_devices=True)
59+
model_builder.save()
60+
61+
def build_inputs_and_outputs(self):
62+
63+
if self.frame_features:
64+
65+
serialized_examples = tf.placeholder(tf.string, shape=(None,))
66+
67+
fn = lambda x: self.build_prediction_graph(x)
68+
video_id_output, top_indices_output, top_predictions_output = (
69+
tf.map_fn(fn, serialized_examples,
70+
dtype=(tf.string, tf.int32, tf.float32)))
71+
72+
else:
73+
74+
serialized_examples = tf.placeholder(tf.string, shape=(None,))
75+
76+
video_id_output, top_indices_output, top_predictions_output = (
77+
self.build_prediction_graph(serialized_examples))
78+
79+
inputs = {"example_bytes":
80+
saved_model_utils.build_tensor_info(serialized_examples)}
81+
82+
outputs = {
83+
"video_id": saved_model_utils.build_tensor_info(video_id_output),
84+
"class_indexes": saved_model_utils.build_tensor_info(top_indices_output),
85+
"predictions": saved_model_utils.build_tensor_info(top_predictions_output)}
86+
87+
return inputs, outputs
88+
89+
def build_prediction_graph(self, serialized_examples):
90+
91+
video_id, model_input_raw, labels_batch, num_frames = (
92+
self.reader.prepare_serialized_examples(serialized_examples))
93+
94+
feature_dim = len(model_input_raw.get_shape()) - 1
95+
model_input = tf.nn.l2_normalize(model_input_raw, feature_dim)
96+
97+
with tf.name_scope("model"):
98+
result = self.model.create_model(
99+
model_input,
100+
num_frames=num_frames,
101+
vocab_size=self.reader.num_classes,
102+
labels=labels_batch)
103+
104+
for variable in slim.get_model_variables():
105+
tf.summary.histogram(variable.op.name, variable)
106+
107+
predictions = result["predictions"]
108+
109+
top_predictions, top_indices = tf.nn.top_k(predictions,
110+
_TOP_PREDICTIONS_IN_OUTPUT)
111+
return video_id, top_indices, top_predictions

readers.py

+11
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ def prepare_reader(self, filename_queue, batch_size=1024):
103103
reader = tf.TFRecordReader()
104104
_, serialized_examples = reader.read_up_to(filename_queue, batch_size)
105105

106+
tf.add_to_collection("serialized_examples", serialized_examples)
107+
return self.prepare_serialized_examples(serialized_examples)
108+
109+
def prepare_serialized_examples(self, serialized_examples):
106110
# set the mapping from the fields to data types in the proto
107111
num_features = len(self.feature_names)
108112
assert num_features > 0, "self.feature_names is empty!"
@@ -117,6 +121,7 @@ def prepare_reader(self, filename_queue, batch_size=1024):
117121
[self.feature_sizes[feature_index]], tf.float32)
118122

119123
features = tf.parse_example(serialized_examples, features=feature_map)
124+
120125
labels = tf.sparse_to_indicator(features["labels"], self.num_classes)
121126
labels.set_shape([None, self.num_classes])
122127
concatenated_features = tf.concat([
@@ -203,6 +208,12 @@ def prepare_reader(self,
203208
reader = tf.TFRecordReader()
204209
_, serialized_example = reader.read(filename_queue)
205210

211+
return self.prepare_serialized_examples(serialized_example,
212+
max_quantized_value, min_quantized_value)
213+
214+
def prepare_serialized_examples(self, serialized_example,
215+
max_quantized_value=2, min_quantized_value=-2):
216+
206217
contexts, features = tf.parse_single_sequence_example(
207218
serialized_example,
208219
context_features={"video_id": tf.FixedLenFeature(

0 commit comments

Comments
 (0)