-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathexport.py
331 lines (282 loc) · 12.5 KB
/
export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
#!/usr/bin/python
#
# Copyright 2020 Brown Visual Computing Lab / Authors of the accompanying paper Matryodshka #
# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This file has been modified by Brown Visual Computing Lab / Authors of the accompanying paper Matryodshka
"""
This script exports the trained model to a frozen .pb file, which is later converted to an ONNX file for TensorRT deployment.
"""
from __future__ import division
import sys
sys.path.append('/usr/lib/python2.7/dist-packages')
import tensorflow as tf
from matryodshka.msi import MSI
from matryodshka.utils import build_matrix
from geometry.sampling import bilinear_wrapper
import numpy as np
import os
flags = tf.app.flags

# Input flags
flags.DEFINE_integer('width', 640, 'Image width')
flags.DEFINE_integer('height', 320, 'Image height')
# NOTE(review): xoffset/yoffset/zoffset and xshift/yshift are not read in this
# script; presumably they are read through FLAGS by the matryodshka modules.
flags.DEFINE_float('xoffset', 0.0,
                   'Camera x-offset from first to second image.')
flags.DEFINE_float('yoffset', 0.0,
                   'Camera y-offset from first to second image.')
flags.DEFINE_float('zoffset', 0.0,
                   'Camera z-offset from first to second image.')
flags.DEFINE_float('min_depth', 1, 'Minimum scene depth.')
flags.DEFINE_float('max_depth', 100, 'Maximum scene depth.')
flags.DEFINE_integer(
    'xshift', 0, 'Horizontal pixel shift for image2 '
    '(i.e., difference in x-coordinate of principal point '
    'from image2 to image1).')
flags.DEFINE_integer(
    'yshift', 0, 'Vertical pixel shift for image2 '
    '(i.e., difference in y-coordinate of principal point '
    'from image2 to image1).')
# Poses are parsed by pose_from_flag(); an empty string means identity.
flags.DEFINE_string('pose1', '',
                    ('Camera pose for first image (if not identity).'
                     ' Twelve space- or comma-separated floats, forming a 3x4'
                     ' matrix in row-major order.'))
flags.DEFINE_string('pose2', '',
                    ('Pose for second image (if not identity).'
                     ' Twelve space- or comma-separated floats, forming a 3x4'
                     ' matrix in row-major order. If pose2 is specified, then'
                     ' xoffset/yoffset/zoffset flags will be used for rendering'
                     ' output views only.'))
# Remap files are loaded with np.load() only when --remap is set.
flags.DEFINE_string('remap_ref', '',
                    ('Remap file for reference image.'))
flags.DEFINE_string('remap_src', '',
                    ('Remap file for source image.'))

# Output flags
flags.DEFINE_string('test_outputs', '',
                    'Which outputs to save. Can concat the following with "_": '
                    '[src_image, ref_image, tgt_image, psp (for perspective crop), hres_tgt_image, '
                    ' src_output_image, ref_output_image, psv, alphas, blend_weights, rgba_layers].')

# Model flags. Defaults are the model described in the SIGGRAPH 2018 paper. See
# README for more details.
# NOTE(review): the default model_name below does not obviously match that
# paper's model; this comment may be inherited from upstream -- confirm.
flags.DEFINE_string('model_root', 'checkpoints/',
                    'Root directory for model checkpoints.')
flags.DEFINE_string('model_name', 'ods-wotemp-elpips-coord',
                    'Name of the model to use for inference.')
flags.DEFINE_string('which_color_pred', 'blend_psv',
                    'Color output format: [blend_psv, blend_bg, blend_bg_psv, alpha_only].')
flags.DEFINE_integer('num_psv_planes', 32, 'Number of planes for PSV.')
flags.DEFINE_integer('num_msi_planes', 32, 'Number of msi planes to infer.')
flags.DEFINE_integer('ngf', 64, 'Number of filters.')
# main() writes the frozen graph to export/<pb_output>.pb.
flags.DEFINE_string('pb_output', 'matryodshka', 'name of the pb file')

# Graph export settings
flags.DEFINE_boolean('clip', False,
                     'Clip weights by float16 range.')
flags.DEFINE_boolean('flip_y', False,
                     'Flip y axis in input images')
flags.DEFINE_boolean('flip_channels', False,
                     'Flip channels in input image')
flags.DEFINE_boolean('rgba', False,
                     'Is image rgba')
flags.DEFINE_boolean('remap', False,
                     'Whether or not to remap')
# NOTE(review): net_only, smoothed and jitter are not read in this script;
# presumably consumed elsewhere through FLAGS.
flags.DEFINE_boolean('net_only', False,
                     'Extract only the network')
flags.DEFINE_boolean('smoothed', False,
                     'Smooth conv2d transpose ops')
flags.DEFINE_boolean('jitter', False, 'jitter for transform inverse training.')

# Camera models, input, output, internal MSI representation
# NOTE(review): input_type, operation, supervision, transform_inverse_reg and
# coord_net are not read in this script; presumably consumed elsewhere.
flags.DEFINE_string('input_type', 'ODS',
                    'Input image type. [PP, ODS]')
flags.DEFINE_string('operation', 'train',
                    'Which operation to perform. [train, export]')
flags.DEFINE_string('supervision', 'tgt', "Images to supervise on. [tgt, ref, src, hrestgt] concatenated with _")
flags.DEFINE_boolean('transform_inverse_reg', False, 'Whether to train with transform-inverse regularization.')
flags.DEFINE_boolean('coord_net', False, 'Whether to append CoordNet during convolution.')

# Set flags
FLAGS = flags.FLAGS
def crop_to_multiple(image, size):
    """Centre-crop an (H, W, C) image so H and W become multiples of ``size``.

    The height and width are read from ``image.get_shape().as_list()`` and
    must therefore be known integers. The crop is a plain slice, so the
    result is of the same kind as ``image``.

    When the number of rows (columns) to drop is odd, floor division trims
    one fewer from the top (left) than from the bottom (right).
    """
    dims = image.get_shape().as_list()
    excess_h = dims[0] % size
    excess_w = dims[1] % size
    keep_h = dims[0] - excess_h
    keep_w = dims[1] - excess_w
    top = excess_h // 2
    left = excess_w // 2
    return image[top:top + keep_h, left:left + keep_w, :]
def process_image(raw, height, width, channels, padx, pady, remap_file):
    """Decode a flat uint8 image tensor into a padded float32 RGB image.

    Steps, in this order:
      1. reshape ``raw`` to (height, width, channels);
      2. keep only the first three channels when --rgba is set;
      3. convert to float32 with tf.image.convert_image_dtype (0-1 range);
      4. resample through ``remap_file`` when --remap is set;
      5. reverse the row axis when --flip_y is set;
      6. reverse the channel axis when --flip_channels is set;
      7. pad ``pady`` rows above/below and ``padx`` columns left/right.

    The returned tensor's static shape is declared as [None, None, 3].
    """
    img = tf.reshape(raw, (height, width, channels))
    if FLAGS.rgba:
        img = img[:, :, :3]  # drop the alpha channel
    img = tf.image.convert_image_dtype(img, tf.float32)
    if FLAGS.remap:
        img = remap_image(img, remap_file)
    if FLAGS.flip_y:
        img = tf.reverse(img, axis=[0])
    if FLAGS.flip_channels:
        img = tf.reverse(img, axis=[2])
    paddings = [[pady, pady], [padx, padx], [0, 0]]
    img = tf.pad(img, paddings)
    img.set_shape([None, None, 3])  # RGB images have 3 channels.
    return img
def remap_image(image, remap_file):
    """Resample ``image`` through the lookup grid stored in ``remap_file``.

    Args:
      image: image tensor without a batch axis (presumably the (H, W, 3)
        float32 image built in process_image -- verify for other callers).
      remap_file: path to a NumPy file read with ``np.load``; its array gets
        a leading batch axis and is passed to ``bilinear_wrapper``.

    Returns:
      The resampled image with the temporary batch axis removed.
    """
    remap_vals = np.load(remap_file)
    remap_tensor = tf.expand_dims(tf.convert_to_tensor(remap_vals), 0)
    batched = tf.expand_dims(image, 0)
    # Squeeze only the batch axis added above. A bare tf.squeeze would also
    # drop any other size-1 axis and, when some dimensions are statically
    # unknown, would leave the result with an unknown rank.
    return tf.squeeze(bilinear_wrapper(batched, remap_tensor), axis=[0])
def pose_from_flag(flag):
    """Parse a pose flag string into a 4x4 row-major matrix (nested lists).

    Args:
      flag: twelve floats separated by spaces and/or commas, giving the top
        three rows of the matrix in row-major order. An empty string (or any
        falsy value) selects the identity pose.

    Returns:
      A 4x4 list of lists of floats whose last row is [0, 0, 0, 1].

    Raises:
      ValueError: if ``flag`` does not hold exactly twelve values, or one of
        them cannot be parsed as a float.
    """
    if not flag:
        return [[1.0, 0.0, 0.0, 0.0],
                [0.0, 1.0, 0.0, 0.0],
                [0.0, 0.0, 1.0, 0.0],
                [0.0, 0.0, 0.0, 1.0]]
    values = [float(x) for x in flag.replace(',', ' ').split()]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(values) != 12:
        raise ValueError(
            'Expected 12 pose values, got %d: %r' % (len(values), flag))
    return [values[0:4], values[4:8], values[8:12], [0.0, 0.0, 0.0, 1.0]]
def get_inputs(padx, pady, width, height):
    """Get images, poses and intrinsics in required format.

    Creates two flat uint8 placeholders named 'ref_image' and 'src_image',
    each holding width * height * channels values (channels is 4 with
    --rgba, otherwise 3). Each is decoded by process_image and cropped to a
    multiple of 16 pixels. Each then gets a leading batch axis of size 1,
    guarded by a runtime assertion that both images agree in shape before
    and after cropping.

    Args:
      padx: columns of padding added on the left and right of each image.
      pady: rows of padding added above and below each image.
      width: width in pixels of the raw input images.
      height: height in pixels of the raw input images.

    Returns:
      dict with keys:
        'ref_image', 'src_image': processed images with a batch axis;
        'ref_pose', 'src_pose', 'intrinsics': matrices with a batch axis;
        'ref_pose_inv', 'src_pose_inv', 'intrinsics_inv': their inverses;
        'hres_ref', 'hres_src', 'hres_tgt': always None in this script.
    """
    inputs = {}

    # Images arrive as flat 1-D uint8 buffers; the one-element tuple makes the
    # rank-1 placeholder shape explicit.
    channels = 4 if FLAGS.rgba else 3
    flat_size = width * height * channels
    image1 = tf.placeholder(tf.uint8, (flat_size,), name='ref_image')
    image2 = tf.placeholder(tf.uint8, (flat_size,), name='src_image')
    with tf.name_scope('process_image'):
        image1 = process_image(image1, height, width, channels, padx, pady, FLAGS.remap_ref)
        image2 = process_image(image2, height, width, channels, padx, pady, FLAGS.remap_src)

    # Crop both images to multiples of 16 and verify at run time that the two
    # shapes match before and after the crop.
    shape1_before_crop = tf.shape(image1)
    shape2_before_crop = tf.shape(image2)
    image1 = crop_to_multiple(image1, 16)
    image2 = crop_to_multiple(image2, 16)
    shape1_after_crop = tf.shape(image1)
    shape2_after_crop = tf.shape(image2)
    with tf.control_dependencies([
            tf.Assert(
                tf.reduce_all(
                    tf.logical_and(
                        tf.equal(shape1_before_crop, shape2_before_crop),
                        tf.equal(shape1_after_crop, shape2_after_crop))), [
                            'Shape mismatch:', shape1_before_crop, shape2_before_crop,
                            shape1_after_crop, shape2_after_crop
                        ])
    ]):
        # Add batch dimension (size 1); the assertion runs before these ops.
        image1 = tf.expand_dims(image1, 0)
        image2 = tf.expand_dims(image2, 0)

    # Poses: identity unless --pose1 / --pose2 are given.
    pose_one = pose_from_flag(FLAGS.pose1)
    pose_two = pose_from_flag(FLAGS.pose2)
    with tf.name_scope('build_matrices'):
        pose_one = build_matrix(pose_one)
        pose_two = build_matrix(pose_two)
        pose_one = tf.expand_dims(pose_one, 0)
        pose_two = tf.expand_dims(pose_two, 0)
    # NOTE(review): fixed intrinsics; the meaning of the 0.032 entry is not
    # established in this file -- confirm against the camera model.
    intrinsics = build_matrix([[0.032, 0.0, 0.0],
                               [0.0, 1.0, 0.0],
                               [0.0, 0.0, 1.0]])
    intrinsics = tf.expand_dims(intrinsics, 0)

    # Set inputs
    inputs['ref_image'] = image1
    inputs['src_image'] = image2
    inputs['ref_pose'] = pose_one
    inputs['src_pose'] = pose_two
    inputs['intrinsics'] = intrinsics

    # Second order inputs
    inputs['ref_pose_inv'] = tf.matrix_inverse(inputs['ref_pose'], name='ref_pose_inv')
    inputs['src_pose_inv'] = tf.matrix_inverse(inputs['src_pose'], name='src_pose_inv')
    inputs['intrinsics_inv'] = tf.matrix_inverse(inputs['intrinsics'], name='intrinsics_inv')

    # High-resolution inputs are not provided for export.
    inputs['hres_ref'] = None
    inputs['hres_src'] = None
    inputs['hres_tgt'] = None
    return inputs
def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    """
    Freezes the state of a session into a pruned computation graph.

    Creates a new computation graph where variable nodes are replaced by
    constants taking their current value in the session. The new graph will be
    pruned so subgraphs that are not necessary to compute the requested
    outputs are removed.

    @param session The TensorFlow session to be frozen.
    @param keep_var_names A list of variable names that should not be frozen,
                          or None to freeze all the variables in the graph.
    @param output_names Names of the relevant graph outputs. The caller's
                        sequence is not modified.
    @param clear_devices Remove the device directives from the graph for better portability.
    @return The frozen graph definition.
    """
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(
            set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        # Copy before extending: `+=` on the caller's list would mutate it in
        # place (and would fail outright for a tuple).
        output_names = list(output_names or [])
        # Variable nodes are also listed as outputs, so pruning keeps them in
        # the frozen graph.
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            session, input_graph_def, output_names, freeze_var_names)
        return frozen_graph
def _clip_variables_to_float16(sess):
    """Clip every trainable variable into the float16 range, in place.

    Each variable's value is read with ``sess.run``. When any element lies
    outside [tf.float16.min, tf.float16.max], the variable's name and value
    range are printed, and the clipped array is assigned back with
    ``tf.assign``.

    Args:
      sess: session holding the restored variables.
    """
    fp16_min = tf.float16.min
    fp16_max = tf.float16.max
    for var in tf.trainable_variables():
        value = sess.run(var)
        lo = np.amin(value)
        hi = np.amax(value)
        if hi > fp16_max or lo < fp16_min:
            print('Clipping %s to float16 range (min=%s, max=%s)' % (var.name, lo, hi))
            sess.run(tf.assign(var, np.clip(value, fp16_min, fp16_max)))


def main(_):
    """Export the MSI inference graph to ``export/<pb_output>.pb``.

    Steps:
      1. Build the inputs (get_inputs) and the MSI network.
      2. Restore the trainable variables from the latest checkpoint in
         ``<model_root>/<model_name>``.
      3. With --clip, clip those variables to the float16 range.
      4. Freeze the graph with 'msi_output' as the output node and write the
         serialized GraphDef.

    Raises:
      ValueError: if no checkpoint is found in the checkpoint directory.
    """
    # No padding is applied to the input images in the exported graph.
    pady = 0
    padx = 0
    inputs = get_inputs(padx, pady, FLAGS.width, FLAGS.height)

    # Build the network.
    model = MSI()
    psv_planes = model.inv_depths(FLAGS.min_depth, FLAGS.max_depth,
                                  FLAGS.num_psv_planes)
    # NOTE(review): msi_planes is never used below (infer_msi receives
    # FLAGS.num_msi_planes instead); kept in case inv_depths has graph side
    # effects -- confirm and remove.
    msi_planes = model.inv_depths(FLAGS.min_depth, FLAGS.max_depth,
                                  FLAGS.num_msi_planes)
    # Called to build the graph. freeze_session below keeps only what
    # 'msi_output' (presumably named inside infer_msi) and the variables need,
    # so the returned tensors are not used here.
    model.infer_msi(
        inputs['src_image'], inputs['ref_image'], inputs['hres_src'], inputs['hres_ref'],
        inputs['ref_pose'], inputs['src_pose'], inputs['intrinsics'],
        FLAGS.which_color_pred, FLAGS.num_msi_planes, psv_planes, FLAGS.test_outputs, ngf=FLAGS.ngf)

    # Only trainable variables are restored from the checkpoint.
    saver = tf.train.Saver(tf.trainable_variables())
    ckpt_dir = os.path.join(FLAGS.model_root, FLAGS.model_name)
    ckpt_file = tf.train.latest_checkpoint(ckpt_dir)
    if ckpt_file is None:
        raise ValueError('No checkpoint found in %s' % ckpt_dir)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    # Entering the session context also installs it as the default session.
    with tf.Session(config=config) as sess:
        saver.restore(sess, ckpt_file)
        if FLAGS.clip:
            _clip_variables_to_float16(sess)
        frozen = freeze_session(sess, output_names=['msi_output'], clear_devices=False)

    # Write out. Create the output directory first so that opening the file
    # does not fail on a fresh checkout.
    output_dir = 'export'
    tf.gfile.MakeDirs(output_dir)
    output_path = os.path.join(output_dir, '%s.pb' % FLAGS.pb_output)
    with tf.gfile.GFile(output_path, "wb") as f:
        f.write(frozen.SerializeToString())
if __name__ == '__main__':
    # Parse the command-line flags, then call main() explicitly.
    tf.app.run(main=main)