Commit 6d770dd — Add Moe and DBoF models
Parent: b9adbb4

12 files changed: +452 −86 lines

README.md (+36 −3)

```diff
@@ -4,8 +4,9 @@ This repo contains starter code for training and evaluating machine learning
 models over the [YouTube-8M](https://research.google.com/youtube8m/) dataset.
 The code gives an end-to-end working example for reading the dataset, training a
 TensorFlow model, and evaluating the performance of the model. Out of the box,
-you can train a logistic classification model over either frame-level or
-video-level features. The code can be extended to train more complex models.
+you can train several [model architectures](#overview-of-models) over either
+frame-level or video-level features. The code can easily be extended to train
+your own custom-defined models.
 
 It is possible to train and evaluate on YouTube-8M in two ways: on your own
 machine, or on Google Cloud. This README provides instructions for both.
@@ -25,6 +26,9 @@ machine, or on Google Cloud. This README provides instructions for both.
    * [Using Frame-Level Features](#using-frame-level-features-1)
    * [Using Audio Features](#using-audio-features-1)
 * [Ground-Truth Label Files](#ground-truth-label-files)
+* [Overview of Models](#overview-of-models)
+   * [Video-Level Models](#video-level-models)
+   * [Frame-Level Models](#frame-level-models)
 * [Overview of Files](#overview-of-files)
    * [Training](#training)
    * [Evaluation](#evaluation)
@@ -85,6 +89,7 @@ JOB_NAME=yt8m_eval_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug beta ml jobs
 submit training $JOB_NAME \
 --package-path=youtube-8m --module-name=youtube-8m.eval \
 --staging-bucket=$BUCKET_NAME --region=us-central1 \
+--config=youtube-8m/cloudml-gpu.yaml \
 -- --eval_data_pattern='gs://youtube8m-ml/1/video_level/validate/validate*.tfrecord' \
 --train_dir=$BUCKET_NAME/${JOB_TO_EVAL}
 ```
@@ -97,6 +102,7 @@ JOB_NAME=yt8m_inference_$(date +%Y%m%d_%H%M%S); gcloud --verbosity=debug beta ml
 submit training $JOB_NAME \
 --package-path=youtube-8m --module-name=youtube-8m.inference \
 --staging-bucket=$BUCKET_NAME --region=us-central1 \
+--config=youtube-8m/cloudml-gpu.yaml \
 -- --input_data_pattern='gs://youtube8m-ml/1/video_level/test/test*.tfrecord' \
 --train_dir=$BUCKET_NAME/${JOB_TO_EVAL} \
 --output_file=$BUCKET_NAME/${JOB_TO_EVAL}/predictions.csv
@@ -314,12 +320,39 @@ id 'VIDEO_ID' and two labels 'LABEL1' and 'LABEL2' we store the following line:
 VIDEO_ID,LABEL1 LABEL2
 ```
 
+## Overview of Models
+
+This sample code contains implementations of three of the models given in the
+[YouTube-8M technical report](https://arxiv.org/abs/1609.08675).
+
+### Video-Level Models
+* `LogisticModel`: Linear projection of the output features into the label
+                   space, followed by a sigmoid function to convert logit
+                   values to probabilities.
+* `MoeModel`: A per-class softmax distribution over a configurable number of
+              logistic classifiers. One of the classifiers in the mixture
+              is not trained, and always predicts 0.
+
+### Frame-Level Models
+* `DBoFModel`: Projects the features for each frame into a higher dimensional
+               'clustering' space, pools across frames in that space, and then
+               uses a video-level model to classify the now aggregated features.
+* `FrameLevelLogisticModel`: Equivalent to `LogisticModel`, but performs
+                             average-pooling on the fly over frame-level
+                             features rather than using pre-aggregated features.
+
 ## Overview of Files
 
 ### Training
 * `train.py`: The primary script for training models.
 * `losses.py`: Contains definitions for loss functions.
-* `models.py`: Contains definitions for models.
+* `models.py`: Contains the base class for defining a model.
+* `video_level_models.py`: Contains definitions for models that take
+                           aggregated features as input.
+* `frame_level_models.py`: Contains definitions for models that take
+                           frame-level features as input.
+* `model_utils.py`: Contains functions that are of general utility for
+                    implementing models.
 * `readers.py`: Contains definitions for the Video dataset and Frame
                 dataset readers.
```
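The `MoeModel` described above is a per-class softmax mixture over logistic classifiers, with one mixture component held fixed at 0. The repo's TensorFlow implementation is not shown in this diff, so the following is only a minimal NumPy sketch of that idea; the function name `moe_predict`, the weight layouts, and the way the dummy expert's gate is simply dropped from the weighted sum are illustrative assumptions, not the repo's code.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def moe_predict(features, gate_w, expert_w, num_mixtures=2):
    """Sketch of a per-class mixture of logistic experts.

    One extra gate slot belongs to an untrained expert that always
    predicts 0, so its term vanishes from the weighted sum.
    gate_w:   [num_features, num_classes * (num_mixtures + 1)]
    expert_w: [num_features, num_classes * num_mixtures]
    """
    num_classes = expert_w.shape[1] // num_mixtures
    # Softmax over the (num_mixtures + 1) gates of each class.
    gates = softmax(
        (features @ gate_w).reshape(-1, num_classes, num_mixtures + 1))
    experts = sigmoid(
        (features @ expert_w).reshape(-1, num_classes, num_mixtures))
    # Drop the gate of the always-zero expert and mix the rest.
    return (gates[:, :, :num_mixtures] * experts).sum(axis=2)

rng = np.random.default_rng(0)
preds = moe_predict(rng.standard_normal((4, 16)),
                    rng.standard_normal((16, 3 * 3)),
                    rng.standard_normal((16, 3 * 2)))
# preds has shape (4, 3); every entry lies in [0, 1)
```

Because the gates sum to 1 across all components and the zero-expert contributes nothing, each class probability is a convex combination of sigmoid outputs and 0, so it stays strictly below 1.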

cloudml-4gpu.yaml (+6)

```diff
@@ -0,0 +1,6 @@
+trainingInput:
+  scaleTier: CUSTOM
+  # standard_gpu provides 1 GPU. Change to complex_model_m_gpu for 4 GPUs
+  masterType: complex_model_m_gpu
+  args: ["--num_gpus", "4"]
+  runtimeVersion: "1.0"
```

cloudml-gpu.yaml (+1 −1)

```diff
@@ -2,4 +2,4 @@ trainingInput:
   scaleTier: CUSTOM
   # standard_gpu provides 1 GPU. Change to complex_model_m_gpu for 4 GPUs
   masterType: standard_gpu
-  # runtimeVersion: "0.12"
+  runtimeVersion: "1.0"
```

eval.py (+4 −2)

```diff
@@ -17,7 +17,8 @@
 
 import eval_util
 import losses
-import models
+import frame_level_models
+import video_level_models
 import readers
 import tensorflow as tf
 from tensorflow import app
@@ -282,7 +283,8 @@ def evaluate():
   reader = readers.YT8MAggregatedFeatureReader(feature_names=feature_names,
                                                feature_sizes=feature_sizes)
 
-  model = find_class_by_name(FLAGS.model, [models])()
+  model = find_class_by_name(FLAGS.model,
+                             [frame_level_models, video_level_models])()
   label_loss_fn = find_class_by_name(FLAGS.label_loss, [losses])()
 
   if FLAGS.eval_data_pattern is "":
```
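The eval.py change above resolves the `--model` flag against two modules instead of one. The body of `find_class_by_name` is not part of this diff; presumably it is a small getattr-based lookup along the following lines (a sketch — the `types.SimpleNamespace` stand-ins for `frame_level_models` / `video_level_models` are purely illustrative):

```python
import types

def find_class_by_name(name, modules):
    """Return the first attribute called `name` found in `modules`."""
    candidates = [getattr(module, name, None) for module in modules]
    return next(c for c in candidates if c is not None)

# Hypothetical stand-ins for the frame_level_models and
# video_level_models modules, each exporting one model class:
frame_level = types.SimpleNamespace(DBoFModel=type("DBoFModel", (), {}))
video_level = types.SimpleNamespace(MoeModel=type("MoeModel", (), {}))

cls = find_class_by_name("MoeModel", [frame_level, video_level])
model = cls()  # instantiated with a trailing (), exactly as in eval.py
```

Searching a list of modules this way is what lets the same `--model` flag name either a frame-level or a video-level class without the caller knowing which file defines it.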

frame_level_models.py (+193)

```diff
@@ -0,0 +1,193 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS-IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains a collection of models which operate on variable-length sequences.
+"""
+import math
+
+import models
+import video_level_models
+import tensorflow as tf
+import model_utils as utils
+
+import tensorflow.contrib.slim as slim
+from tensorflow import flags
+
+FLAGS = flags.FLAGS
+flags.DEFINE_integer("iterations", 30,
+                     "Number of frames per batch for DBoF.")
+flags.DEFINE_bool("dbof_add_batch_norm", True,
+                  "Adds batch normalization to the DBoF model.")
+flags.DEFINE_bool(
+    "sample_random_frames", True,
+    "If true samples random frames (for frame level models). If false, a "
+    "random sequence of frames is sampled instead.")
+flags.DEFINE_integer("dbof_cluster_size", 8192,
+                     "Number of units in the DBoF cluster layer.")
+flags.DEFINE_integer("dbof_hidden_size", 1024,
+                     "Number of units in the DBoF hidden layer.")
+flags.DEFINE_string("dbof_pooling_method", "max",
+                    "The pooling method used in the DBoF cluster layer. "
+                    "Choices are 'average' and 'max'.")
+flags.DEFINE_string("video_level_classifier_model", "MoeModel",
+                    "Some frame-level models can be decomposed into a "
+                    "generalized pooling operation followed by a "
+                    "classifier layer.")
+
+class FrameLevelLogisticModel(models.BaseModel):
+
+  def create_model(self, model_input, vocab_size, num_frames, **unused_params):
+    """Creates a model which uses a logistic classifier over the average of the
+    frame-level features.
+
+    This class is intended to be an example for implementors of frame level
+    models. If you want to train a model over averaged features it is more
+    efficient to average them beforehand rather than on the fly.
+
+    Args:
+      model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
+                   input features.
+      vocab_size: The number of classes in the dataset.
+      num_frames: A vector of length 'batch' which indicates the number of
+                  frames for each video (before padding).
+
+    Returns:
+      A dictionary with a tensor containing the probability predictions of the
+      model in the 'predictions' key. The dimensions of the tensor are
+      'batch_size' x 'num_classes'.
+    """
+    num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32)
+    feature_size = model_input.get_shape().as_list()[2]
+
+    denominators = tf.reshape(
+        tf.tile(num_frames, [1, feature_size]), [-1, feature_size])
+    avg_pooled = tf.reduce_sum(model_input,
+                               axis=[1]) / denominators
+
+    output = slim.fully_connected(
+        avg_pooled, vocab_size, activation_fn=tf.nn.sigmoid,
+        weights_regularizer=slim.l2_regularizer(0.01))
+    return {"predictions": output}
+
+class DBoFModel(models.BaseModel):
+  """Creates a Deep Bag of Frames model.
+
+  The model projects the features for each frame into a higher dimensional
+  'clustering' space, pools across frames in that space, and then
+  uses a configurable video-level model to classify the now aggregated features.
+
+  The model will randomly sample either frames or sequences of frames during
+  training to speed up convergence.
+
+  Args:
+    model_input: A 'batch_size' x 'max_frames' x 'num_features' matrix of
+                 input features.
+    vocab_size: The number of classes in the dataset.
+    num_frames: A vector of length 'batch' which indicates the number of
+                frames for each video (before padding).
+
+  Returns:
+    A dictionary with a tensor containing the probability predictions of the
+    model in the 'predictions' key. The dimensions of the tensor are
+    'batch_size' x 'num_classes'.
+  """
+
+  def create_model(self,
+                   model_input,
+                   vocab_size,
+                   num_frames,
+                   iterations=None,
+                   add_batch_norm=None,
+                   sample_random_frames=None,
+                   cluster_size=None,
+                   hidden_size=None,
+                   is_training=True,
+                   **unused_params):
+    iterations = iterations or FLAGS.iterations
+    add_batch_norm = add_batch_norm or FLAGS.dbof_add_batch_norm
+    random_frames = sample_random_frames or FLAGS.sample_random_frames
+    cluster_size = cluster_size or FLAGS.dbof_cluster_size
+    hidden1_size = hidden_size or FLAGS.dbof_hidden_size
+
+    num_frames = tf.cast(tf.expand_dims(num_frames, 1), tf.float32)
+    if random_frames:
+      model_input = utils.SampleRandomFrames(model_input, num_frames,
+                                             iterations)
+    else:
+      model_input = utils.SampleRandomSequence(model_input, num_frames,
+                                               iterations)
+    max_frames = model_input.get_shape().as_list()[1]
+    feature_size = model_input.get_shape().as_list()[2]
+    reshaped_input = tf.reshape(model_input, [-1, feature_size])
+    tf.summary.histogram("input_hist", reshaped_input)
+
+    if add_batch_norm:
+      reshaped_input = slim.batch_norm(
+          reshaped_input,
+          center=True,
+          scale=True,
+          is_training=is_training,
+          scope="input_bn")
+
+    cluster_weights = tf.Variable(tf.random_normal(
+        [feature_size, cluster_size],
+        stddev=1 / math.sqrt(feature_size)))
+    tf.summary.histogram("cluster_weights", cluster_weights)
+    activation = tf.matmul(reshaped_input, cluster_weights)
+    if add_batch_norm:
+      activation = slim.batch_norm(
+          activation,
+          center=True,
+          scale=True,
+          is_training=is_training,
+          scope="cluster_bn")
+    else:
+      cluster_biases = tf.Variable(
+          tf.random_normal(
+              [cluster_size], stddev=1 / math.sqrt(feature_size)))
+      tf.summary.histogram("cluster_biases", cluster_biases)
+      activation += cluster_biases
+    activation = tf.nn.relu6(activation)
+    tf.summary.histogram("cluster_output", activation)
+
+    activation = tf.reshape(activation, [-1, max_frames, cluster_size])
+    activation = utils.FramePooling(activation, FLAGS.dbof_pooling_method)
+
+    hidden1_weights = tf.Variable(tf.random_normal(
+        [cluster_size, hidden1_size],
+        stddev=1 / math.sqrt(cluster_size)))
+    tf.summary.histogram("hidden1_weights", hidden1_weights)
+    activation = tf.matmul(activation, hidden1_weights)
+    if add_batch_norm:
+      activation = slim.batch_norm(
+          activation,
+          center=True,
+          scale=True,
+          is_training=is_training,
+          scope="hidden1_bn")
+    else:
+      hidden1_biases = tf.Variable(
+          tf.random_normal(
+              [hidden1_size], stddev=0.01))
+      tf.summary.histogram("hidden1_biases", hidden1_biases)
+      activation += hidden1_biases
+    activation = tf.nn.relu6(activation)
+    tf.summary.histogram("hidden1_output", activation)
+
+    aggregated_model = getattr(video_level_models,
+                               FLAGS.video_level_classifier_model)
+    return aggregated_model().create_model(
+        model_input=activation,
+        vocab_size=vocab_size,
+        **unused_params)
```
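The core of `FrameLevelLogisticModel` above is on-the-fly average pooling: sum the zero-padded frame features, then divide each video's sum by its true (pre-padding) frame count rather than by `max_frames`. A minimal NumPy sketch of just that pooling step, assuming padded frames are all zeros (`average_pool` is an illustrative name, not a function from the repo):

```python
import numpy as np

def average_pool(frames, num_frames):
    """Average frame features using each video's true length.

    frames:     [batch, max_frames, num_features], zero-padded past
                each video's true length.
    num_frames: [batch] vector of true frame counts.
    """
    # Matches the TF code: sum over the frame axis, then divide by the
    # per-video frame count, broadcast across the feature dimension.
    denominators = num_frames.astype(np.float32)[:, np.newaxis]
    return frames.sum(axis=1) / denominators

frames = np.zeros((2, 4, 3), dtype=np.float32)
frames[0, :2] = 1.0   # video 0: 2 real frames of all-ones, 2 padded
frames[1, :4] = 2.0   # video 1: 4 real frames of all-twos, no padding
pooled = average_pool(frames, np.array([2, 4]))
# pooled → [[1, 1, 1], [2, 2, 2]]
```

Dividing by `max_frames` instead would shrink the averages of short, padded videos toward zero, which is why the TF code builds the `denominators` tensor from `num_frames`.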

inference.py (−1)

```diff
@@ -29,7 +29,6 @@
 import losses
 import readers
 import utils
-import models
 
 FLAGS = flags.FLAGS
```
3534
