Skip to content

Commit 82104c5

Browse files
committed
Update inference and readers to match the new dataset and Kaggle requirements
1 parent 2a652b0 commit 82104c5

File tree

2 files changed

+13
-13
lines changed

2 files changed

+13
-13
lines changed

inference.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -75,8 +75,8 @@ def format_lines(video_ids, predictions, top_k):
7575
line = [(class_index, predictions[video_index][class_index])
7676
for class_index in top_indices]
7777
line = sorted(line, key=lambda p: -p[1])
78-
yield video_ids[video_index].decode('utf-8') + "," + " ".join("%i %f" % pair
79-
for pair in line) + "\n"
78+
yield video_ids[video_index].decode('utf-8') + "," + " ".join(
79+
"%i" % label for (label, _) in line) + "\n"
8080

8181

8282
def get_input_data_tensors(reader, data_pattern, batch_size, num_readers=1):
@@ -152,7 +152,7 @@ def set_up_init_ops(variables):
152152
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
153153
num_examples_processed = 0
154154
start_time = time.time()
155-
out_file.write("VideoId,LabelConfidencePairs\n")
155+
out_file.write("VideoId,Labels\n")
156156

157157
try:
158158
while not coord.should_stop():

readers.py

+10-10
Original file line number | Diff line number | Diff line change
@@ -72,9 +72,9 @@ class YT8MAggregatedFeatureReader(BaseReader):
7272
"""
7373

7474
def __init__(self,
75-
num_classes=4716,
76-
feature_sizes=[1024],
77-
feature_names=["mean_inc3"]):
75+
num_classes=3862,
76+
feature_sizes=[1024, 128],
77+
feature_names=["mean_rgb", "mean_audio"]):
7878
"""Construct a YT8MAggregatedFeatureReader.
7979
8080
Args:
@@ -114,7 +114,7 @@ def prepare_serialized_examples(self, serialized_examples):
114114
"length of feature_names (={}) != length of feature_sizes (={})".format( \
115115
len(self.feature_names), len(self.feature_sizes))
116116

117-
feature_map = {"video_id": tf.FixedLenFeature([], tf.string),
117+
feature_map = {"id": tf.FixedLenFeature([], tf.string),
118118
"labels": tf.VarLenFeature(tf.int64)}
119119
for feature_index in range(num_features):
120120
feature_map[self.feature_names[feature_index]] = tf.FixedLenFeature(
@@ -126,7 +126,7 @@ def prepare_serialized_examples(self, serialized_examples):
126126
concatenated_features = tf.concat([
127127
features[feature_name] for feature_name in self.feature_names], 1)
128128

129-
return features["video_id"], concatenated_features, labels, tf.ones([tf.shape(serialized_examples)[0]])
129+
return features["id"], concatenated_features, labels, tf.ones([tf.shape(serialized_examples)[0]])
130130

131131
class YT8MFrameFeatureReader(BaseReader):
132132
"""Reads TFRecords of SequenceExamples.
@@ -138,9 +138,9 @@ class YT8MFrameFeatureReader(BaseReader):
138138
"""
139139

140140
def __init__(self,
141-
num_classes=4716,
142-
feature_sizes=[1024],
143-
feature_names=["inc3"],
141+
num_classes=3862,
142+
feature_sizes=[1024, 128],
143+
feature_names=["rgb", "audio"],
144144
max_frames=300):
145145
"""Construct a YT8MFrameFeatureReader.
146146
@@ -215,7 +215,7 @@ def prepare_serialized_examples(self, serialized_example,
215215

216216
contexts, features = tf.parse_single_sequence_example(
217217
serialized_example,
218-
context_features={"video_id": tf.FixedLenFeature(
218+
context_features={"id": tf.FixedLenFeature(
219219
[], tf.string),
220220
"labels": tf.VarLenFeature(tf.int64)},
221221
sequence_features={
@@ -261,7 +261,7 @@ def prepare_serialized_examples(self, serialized_example,
261261

262262
# convert to batch format.
263263
# TODO: Do proper batch reads to remove the IO bottleneck.
264-
batch_video_ids = tf.expand_dims(contexts["video_id"], 0)
264+
batch_video_ids = tf.expand_dims(contexts["id"], 0)
265265
batch_video_matrix = tf.expand_dims(video_matrix, 0)
266266
batch_labels = tf.expand_dims(labels, 0)
267267
batch_frames = tf.expand_dims(num_frames, 0)

0 commit comments

Comments (0)