-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference_from_pb_file-SS.py
138 lines (93 loc) · 3.3 KB
/
inference_from_pb_file-SS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python
# coding: utf-8
# ### Loading the converted TensorRT pb graph
# In[1]:
import tensorflow as tf
from tensorflow.python.platform import gfile
import gi
gi.require_version('Gtk', '3.0')
#GRAPH_PB_PATH = './trained_models_local/saved_for_lab/tf_model_base_1502.pb'
#GRAPH_PB_PATH = './converted_trt_graph/trt_graph_base_30.pb'
#GRAPH_PB_PATH = './converted_trt_graph/trt_graph_st_prg_1601_80p.pb'
#GRAPH_PB_PATH_TRT = './converted_trt_graph/trt_graph_ss_model.pb'
# Path of the serialized GraphDef (.pb) that is loaded for inference.
GRAPH_PB_PATH_FROZEN_SS = './trt_graph_ss_model.pb'

# Request on-demand GPU memory for this first session as well, so that the
# short-lived loader session does not reserve GPU memory with default options
# before the inference session gets configured.
_load_config = tf.ConfigProto()
_load_config.gpu_options.allow_growth = True

# Parse the GraphDef from disk and import it into the default graph.
# `sess` is only used as a scope here; it is closed when the block exits.
with tf.Session(config=_load_config) as sess:
    print("load graph")
    # GFile replaces the deprecated FastGFile; the file is read while it is
    # still open.
    with gfile.GFile(GRAPH_PB_PATH_FROZEN_SS, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # name='' keeps the original node names (no import prefix), so tensors
    # can be looked up by the names stored in the .pb file.
    tf.import_graph_def(graph_def, name='')
    # Node list and node names, kept for inspecting the graph's operations.
    graph_nodes = list(graph_def.node)
    names = [node.name for node in graph_nodes]
    # print operations
    # print(names)
# In[ ]:
import tensorflow as tf
# Ask TensorFlow to grow GPU memory on demand instead of reserving it up
# front. Growth is enabled on every visible GPU, not only the first one,
# so the setting is consistent across devices.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in physical_devices:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # set_memory_growth raises RuntimeError when the devices have already
    # been initialized; report it and keep going with the current settings.
    print(err)
print('done')
# ### Importing the graph
# In[ ]:
# Session used for inference; allow_growth makes it allocate GPU memory on
# demand rather than all at once.
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = True
tf_sess = tf.Session(config=tf_config)
# The GraphDef was already imported into the default graph when it was loaded
# above, and this session uses that same default graph. Importing it a second
# time would only duplicate every node, so it is not repeated here.
# ### loading the first and last layers
# In[ ]:
# Names of the tensors used as the model's entry and exit points.
INPUT_TENSOR_NAME = 'input_1:0'
OUTPUT_TENSOR_NAME = 'sigmoid/Sigmoid:0'

# Resolve the input tensor in the session's graph and show it.
tf_input = tf_sess.graph.get_tensor_by_name(INPUT_TENSOR_NAME)
print(tf_input)

# Resolve the prediction tensor in the session's graph and show it.
tf_predictions = tf_sess.graph.get_tensor_by_name(OUTPUT_TENSOR_NAME)
print(tf_predictions)
# ### Real time prediction of the mask from the camera
# In[ ]:
import cv2
import numpy as np
#import matplotlib.pyplot as plt
from IPython.display import clear_output
import time
from tensorflow.python.keras.backend import set_session
# Default graph, re-entered around each inference call below.
graph = tf.get_default_graph()

# Target size handed to cv2.resize, which takes (width, height).
FRAME_SIZE = (480, 320)

# Capture the video from the camera (0 is the most common ID for webcams).
# For streams, pass a URL instead:
# cap = cv2.VideoCapture('rtsp://url.to.stream/media.amqp')
cap = cv2.VideoCapture(0)

# Read frames, run the model on each one and display the input next to the
# colour-mapped prediction until the stream ends or 'q' is pressed.
try:
    while cap.isOpened():
        ret, frame = cap.read()
        # Check the read result before touching the frame: when the read
        # fails there is no image to resize and cv2.resize would raise.
        if not ret:
            break

        image_resized3 = cv2.resize(frame, FRAME_SIZE)

        # NOTE(review): the frame is fed exactly as OpenCV delivers it (no
        # channel reordering or scaling) -- confirm this matches the
        # preprocessing the model was trained with.
        # Only the prediction tensor is fetched; [None, ...] adds the
        # leading batch axis. The Keras set_session(sess) call was dropped:
        # `sess` is already closed and inference goes through tf_sess.run.
        with graph.as_default():
            predictions = tf_sess.run(
                tf_predictions,
                feed_dict={tf_input: image_resized3[None, ...]},
            )

        # Scale the prediction by 255 and drop singleton axes.
        # NOTE(review): presumably the sigmoid output lies in [0, 1], so the
        # result fits an 8-bit image -- confirm.
        pred_image = 255 * predictions.squeeze()
        # applyColorMap needs an 8-bit image, hence the cast to uint8.
        u8 = pred_image.astype(np.uint8)
        im_color = cv2.applyColorMap(u8, cv2.COLORMAP_AUTUMN)

        cv2.imshow('input image', image_resized3)
        cv2.imshow('prediction mask', im_color)

        # waitKey(1) lets the windows refresh; 'q' stops the loop.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera and close the windows, even if an error or
    # Ctrl+C interrupts the loop.
    cap.release()
    cv2.destroyAllWindows()