-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathinference.py
153 lines (126 loc) · 5.91 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
#!/usr/bin/env python3
import progressbar
import logging
import logging.config
import os
import tensorflow as tf
import numpy as np
from model.resnet import ResNet
from dataset.voc_loader import VOCLoader
from dataset.instance_sampler import InstanceSampler
from utils.utils_tf import fill_and_crop
from configs.paths import CKPT_DIR, RAW_CONTEXT_DIR
from configs.config import (get_logging_config, args,
std_data_augmentation_config)
# import matplotlib
# matplotlib.rcParams['backend'] = "Qt4Agg"
# Short alias for the TF-Slim module that ships under tf.contrib (TF 1.x).
slim = tf.contrib.slim
# Configure logging once, at import time, from the project-provided config
# built for the current run name.
logging.config.dictConfig(get_logging_config(args.run_name))
# Root logger shared by the code in this module.
log = logging.getLogger()
class InferenceModel(object):
    """Scores candidate boxes of an image with a trained network.

    The graph is built once in the constructor: fixed-size placeholders are
    cropped per sample with `fill_and_crop`, passed through `net.build_net`
    and turned into class probabilities with a softmax.
    """

    def __init__(self, sess, net, sampler, img_size,
                 folder=None, context_estimation=False):
        """Store the collaborators and build the inference graph.

        Args:
            sess: TensorFlow session used to run the graph.
            net: network object exposing `build_net(images, num_classes)`.
            sampler: object exposing `num_classes`.
            img_size: stored on the instance as `img_size`.
            folder: unused; kept so existing callers keep working.
            context_estimation: unused; kept so existing callers keep working.
        """
        self.sess = sess
        self.net = net
        self.sampler = sampler
        self.img_size = img_size
        self.build_context_estimator()

    def restore_from_ckpt(self, ckpt):
        """Restore all global variables from a checkpoint of the current run.

        Args:
            ckpt: integer checkpoint id; three zeros are appended to it to
                form the checkpoint file name ('model.ckpt-<ckpt>000').
        """
        ckpt_path = os.path.join(CKPT_DIR, args.run_name,
                                 'model.ckpt-%i000' % ckpt)
        log.debug("Restoring checkpoint %s", ckpt_path)
        self.sess.run(tf.local_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        saver.restore(self.sess, ckpt_path)

    def build_context_estimator(self):
        """Create the placeholders and the crop -> network -> softmax graph."""
        b = args.test_batch_size
        # The batch dimension of every placeholder is fixed to
        # `test_batch_size`, so each feed must hold exactly `b` samples.
        self.images_ph = tf.placeholder(shape=[b, None, None, 3],
                                        dtype=tf.float32, name='img_ph')
        self.bboxes_ph = tf.placeholder(shape=[b, 4],
                                        dtype=tf.float32, name='bboxes_ph')
        self.frames_ph = tf.placeholder(shape=[b, 4],
                                        dtype=tf.float32, name='frames_ph')
        self.ws = tf.placeholder(shape=[b],
                                 dtype=tf.float32, name='ws_ph')
        self.hs = tf.placeholder(shape=[b],
                                 dtype=tf.float32, name='hs_ph')

        def fn(x):
            # x = (image, bbox, frame, w, h) for a single sample.
            return fill_and_crop(x[0], x[1], x[2], x[3], x[4],
                                 std_data_augmentation_config)

        imgs = tf.map_fn(fn, [self.images_ph, self.bboxes_ph,
                              self.frames_ph, self.ws, self.hs],
                         tf.float32, parallel_iterations=4, back_prop=False)
        self.logits = self.net.build_net(imgs, self.sampler.num_classes)
        self.output_probs = tf.nn.softmax(self.logits)

    def estimate_context(self, imgs, bboxes, frames, ws, hs):
        """Run the network over all samples and return probabilities and boxes.

        Args:
            imgs: array of images, one per sample along the first axis.
            bboxes: array with one 4-value box per sample; multiplied by
                (w, h, w, h) below, so presumably in relative coordinates
                (confirm against the sampler).
            frames: array with one 4-value frame per sample.
            ws: per-sample widths.
            hs: per-sample heights.

        Returns:
            Tuple `(probs, bboxes)`: the softmax outputs, and the input boxes
            scaled by (w, h, w, h) and cast to int. When
            `args.n_neighborhoods > 1`, consecutive groups of that many
            samples are treated as one box: the first box of each group is
            kept and the probabilities are averaged over the group.
        """
        b = args.test_batch_size
        n_samples = imgs.shape[0]
        if n_samples == 0:
            # Nothing to score: return empty, correctly shaped results
            # instead of failing inside np.concatenate.
            return (np.zeros((0, self.sampler.num_classes), dtype=np.float32),
                    np.zeros((0, 4), dtype=int))

        batch_probs = []
        for start in range(0, n_samples, b):
            n_valid = min(b, n_samples - start)
            # The placeholders need exactly `b` samples. Pad a trailing
            # partial batch by repeating the last sample and drop the padded
            # outputs, so no sample is silently skipped.
            inds = np.minimum(np.arange(start, start + b), n_samples - 1)
            feed_dict = {self.images_ph: imgs[inds],
                         self.bboxes_ph: bboxes[inds],
                         self.frames_ph: frames[inds],
                         self.ws: ws[inds],
                         self.hs: hs[inds]}
            probs = self.sess.run(self.output_probs, feed_dict=feed_dict)
            batch_probs.append(probs[:n_valid])
        final_probs = np.concatenate(batch_probs, axis=0)
        final_bboxes = np.array(bboxes * np.vstack([ws, hs, ws, hs]).T,
                                dtype=int)

        # If sampling more than one context image per box
        # averaging scores over them
        nn = args.n_neighborhoods
        if nn > 1:
            final_bboxes = final_bboxes[::nn]
            all_probs = [final_probs[i::nn] for i in range(nn)]
            final_probs = np.stack(all_probs, -1).mean(-1)
        return final_probs, final_bboxes
def sample2batch(s):
    """Expand one sample into per-box batch arrays.

    Args:
        s: mapping with keys 'bboxes', 'frames', 'w', 'h' and 'img';
            `s['bboxes']` must expose `.shape`.

    Returns:
        Tuple `(imgs, bboxes, frames, ws, hs)`. `imgs`, `ws` and `hs` repeat
        the image, width and height once per box; `bboxes` and `frames` are
        passed through unchanged.
    """
    boxes = s['bboxes']
    n_boxes = boxes.shape[0]
    width, height, image = s['w'], s['h'], s['img']
    widths = np.array([width for _ in range(n_boxes)])
    heights = np.array([height for _ in range(n_boxes)])
    images = np.array([image for _ in range(n_boxes)])
    return images, boxes, s['frames'], widths, heights
def main(argv=None):  # pylint: disable=unused-argument
    """Compute and store context predictions for every image of a dataset.

    Builds the loader selected by `args.dataset`, restores the checkpoint
    `args.ckpt` of the current run, and for each file name of the loader
    saves a dict with keys 'bboxes' and 'probs' via `np.save` (stored as a
    pickled object, so loading needs `allow_pickle=True`). Images whose
    output file already exists are skipped, so an interrupted run can be
    resumed.

    Raises:
        ValueError: if `args.dataset` matches none of the supported names.
    """
    assert args.ckpt > 0
    assert args.test_n % args.n_neighborhoods == 0, "test_n has to be a multiple of n_neighborhoods"

    net = ResNet(training=False)

    # extracting cats to exclude
    excluded = [int(c) for c in args.excluded.split('_')] if args.excluded != "" else []

    dataset, split = args.dataset, args.split
    if '0712' in dataset:
        # NOTE(review): this branch ignores `split` and `args.small_data`
        # (hard-coded 'train' / True) - confirm that is intended.
        loader = VOCLoader(['07', '12'], 'train', True, subsets=args.subsets,
                           cats_exclude=excluded, cut_bad_names=False)
    elif '12' in dataset:
        loader = VOCLoader('12', split, args.small_data,
                           cats_exclude=excluded,
                           cut_bad_names=False)
    elif '7' in dataset:
        loader = VOCLoader('07', split, cats_exclude=excluded,
                           cut_bad_names=False)
    else:
        # Previously fell through and failed later on an unbound `loader`.
        raise ValueError("Unsupported dataset %r: expected a name containing "
                         "'0712', '12' or '7'" % dataset)

    sampler = InstanceSampler(loader=loader,
                              n_neighborhoods=args.n_neighborhoods)

    suff = '_small' if args.small_data else ''
    context_folder = os.path.join(RAW_CONTEXT_DIR, args.run_name + '-' + dataset
                                  + split + suff + '-%dneib' % args.n_neighborhoods)
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(context_folder, exist_ok=True)

    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                          log_device_placement=False)) as sess:
        estimator = InferenceModel(sess, net, sampler, args.image_size)
        estimator.restore_from_ckpt(args.ckpt)
        bar = progressbar.ProgressBar()
        for name in bar(loader.filenames):
            save_file = os.path.join(context_folder, name)
            # np.save appends '.npy' when the name lacks it, so the existence
            # check must use the final on-disk name or it never matches.
            if not save_file.endswith('.npy'):
                save_file += '.npy'
            if os.path.exists(save_file):
                continue
            sample = sampler.get_test_sample(name, args.test_n)
            imgs, bboxes, frames, ws, hs = sample2batch(sample)
            probs, bboxes_out = estimator.estimate_context(imgs, bboxes,
                                                           frames, ws, hs)
            context_dict = {'bboxes': bboxes_out, 'probs': probs}
            np.save(save_file, context_dict)
    print('DONE')
# Script entry point: let TensorFlow parse the command-line flags and then
# call `main`.
if __name__ == "__main__":
    tf.app.run(main=main)