@@ -64,6 +64,13 @@ def __init__(self, model_dir=MODEL_DIR):
64
64
if not os .path .exists (model_dir ):
65
65
os .makedirs (model_dir )
66
66
67
+ # Load PCA Matrix.
68
+ download_path = self ._maybe_download (YT8M_PCA_MAT )
69
+ pca_mean = os .path .join (self ._model_dir , 'mean.npy' )
70
+ if not os .path .exists (pca_mean ):
71
+ tarfile .open (download_path , 'r:gz' ).extractall (model_dir )
72
+ self ._load_pca ()
73
+
67
74
# Load Inception Network
68
75
download_path = self ._maybe_download (INCEPTION_TF_GRAPH )
69
76
inception_proto_file = os .path .join (
@@ -72,12 +79,7 @@ def __init__(self, model_dir=MODEL_DIR):
72
79
tarfile .open (download_path , 'r:gz' ).extractall (model_dir )
73
80
self ._load_inception (inception_proto_file )
74
81
75
- # Load PCA Matrix.
76
- download_path = self ._maybe_download (YT8M_PCA_MAT )
77
- pca_mean = os .path .join (self ._model_dir , 'mean.npy' )
78
- if not os .path .exists (pca_mean ):
79
- tarfile .open (download_path , 'r:gz' ).extractall (model_dir )
80
- self ._load_pca ()
82
+
81
83
82
84
def extract_rgb_frame_features (self , frame_rgb , apply_pca = True ):
83
85
"""Applies the YouTube8M feature extraction over an RGB frame.
@@ -98,13 +100,8 @@ def extract_rgb_frame_features(self, frame_rgb, apply_pca=True):
98
100
assert len (frame_rgb .shape ) == 3
99
101
assert frame_rgb .shape [2 ] == 3 # 3 channels (R, G, B)
100
102
with self ._inception_graph .as_default ():
101
- frame_features = self .session .run ('pool_3/_reshape :0' ,
103
+ frame_features = self .session .run ('pca_final_feature :0' ,
102
104
feed_dict = {'DecodeJpeg:0' : frame_rgb })
103
- frame_features = frame_features [0 ] # Unbatch.
104
-
105
- if apply_pca :
106
- frame_features = self .apply_pca (frame_features )
107
-
108
105
return frame_features
109
106
110
107
def apply_pca (self , frame_features ):
@@ -148,6 +145,13 @@ def _load_inception(self, proto_file):
148
145
with self ._inception_graph .as_default ():
149
146
_ = tf .import_graph_def (graph_def , name = '' )
150
147
self .session = tf .Session ()
148
+ Frame_Features = self .session .graph .get_tensor_by_name ('pool_3/_reshape:0' )
149
+ Pca_Mean = tf .constant (value = self .pca_mean , dtype = tf .float32 )
150
+ Pca_Eigenvecs = tf .constant (value = self .pca_eigenvecs , dtype = tf .float32 )
151
+ Pca_Eigenvals = tf .constant (value = self .pca_eigenvals , dtype = tf .float32 )
152
+ Feats = Frame_Features [0 ] - Pca_Mean
153
+ Feats = tf .reshape (tf .matmul (tf .reshape (Feats , [1 , 2048 ]), Pca_Eigenvecs ), [1024 , ])
154
+ tf .divide (Feats , tf .sqrt (Pca_Eigenvals + 1e-4 ), name = 'pca_final_feature' )
151
155
152
156
def _load_pca (self ):
153
157
self .pca_mean = numpy .load (
0 commit comments