Skip to content

Commit 2c94ed4

Browse files
authored
Merge pull request google#69 from ducklingll/master
use tf rewrite pca operation
2 parents 2bef29d + d65ca6a commit 2c94ed4

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

feature_extractor/feature_extractor.py

+16-12
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ def __init__(self, model_dir=MODEL_DIR):
6464
if not os.path.exists(model_dir):
6565
os.makedirs(model_dir)
6666

67+
# Load PCA Matrix.
68+
download_path = self._maybe_download(YT8M_PCA_MAT)
69+
pca_mean = os.path.join(self._model_dir, 'mean.npy')
70+
if not os.path.exists(pca_mean):
71+
tarfile.open(download_path, 'r:gz').extractall(model_dir)
72+
self._load_pca()
73+
6774
# Load Inception Network
6875
download_path = self._maybe_download(INCEPTION_TF_GRAPH)
6976
inception_proto_file = os.path.join(
@@ -72,12 +79,7 @@ def __init__(self, model_dir=MODEL_DIR):
7279
tarfile.open(download_path, 'r:gz').extractall(model_dir)
7380
self._load_inception(inception_proto_file)
7481

75-
# Load PCA Matrix.
76-
download_path = self._maybe_download(YT8M_PCA_MAT)
77-
pca_mean = os.path.join(self._model_dir, 'mean.npy')
78-
if not os.path.exists(pca_mean):
79-
tarfile.open(download_path, 'r:gz').extractall(model_dir)
80-
self._load_pca()
82+
8183

8284
def extract_rgb_frame_features(self, frame_rgb, apply_pca=True):
8385
"""Applies the YouTube8M feature extraction over an RGB frame.
@@ -98,13 +100,8 @@ def extract_rgb_frame_features(self, frame_rgb, apply_pca=True):
98100
assert len(frame_rgb.shape) == 3
99101
assert frame_rgb.shape[2] == 3 # 3 channels (R, G, B)
100102
with self._inception_graph.as_default():
101-
frame_features = self.session.run('pool_3/_reshape:0',
103+
frame_features = self.session.run('pca_final_feature:0',
102104
feed_dict={'DecodeJpeg:0': frame_rgb})
103-
frame_features = frame_features[0] # Unbatch.
104-
105-
if apply_pca:
106-
frame_features = self.apply_pca(frame_features)
107-
108105
return frame_features
109106

110107
def apply_pca(self, frame_features):
@@ -148,6 +145,13 @@ def _load_inception(self, proto_file):
148145
with self._inception_graph.as_default():
149146
_ = tf.import_graph_def(graph_def, name='')
150147
self.session = tf.Session()
148+
Frame_Features = self.session.graph.get_tensor_by_name('pool_3/_reshape:0')
149+
Pca_Mean = tf.constant(value=self.pca_mean, dtype=tf.float32)
150+
Pca_Eigenvecs = tf.constant(value=self.pca_eigenvecs, dtype=tf.float32)
151+
Pca_Eigenvals = tf.constant(value=self.pca_eigenvals, dtype=tf.float32)
152+
Feats = Frame_Features[0] - Pca_Mean
153+
Feats = tf.reshape(tf.matmul(tf.reshape(Feats, [1, 2048]), Pca_Eigenvecs), [1024, ])
154+
tf.divide(Feats, tf.sqrt(Pca_Eigenvals + 1e-4), name='pca_final_feature')
151155

152156
def _load_pca(self):
153157
self.pca_mean = numpy.load(

0 commit comments

Comments
 (0)