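"""Interactive 2D/3D annotation tool (app.py).

Click keypoints in the "2D Annotator" window to prompt SAM for a 2D mask,
then project the mask onto a point cloud or mesh shown in the Open3D
"3D Visualizer" to select, combine, and export 3D points.
"""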
import cv2
import imageio
import numpy as np
import argparse
import os
import open3d as o3d
from point_utils import project_pcd, get_depth_map, unproject_pcd, covisibility_mask, mask_pcd_2d
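
# Command-line options: annotate a single image, or frames of a NeRF-synthetic
# dataset (with camera poses), optionally together with a point cloud or mesh.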
parser = argparse.ArgumentParser()
parser.add_argument("--image", type=str, default="demo.png")
parser.add_argument("--wo_sam", action="store_true")
parser.add_argument("--save_path", type=str, default="output/")
parser.add_argument("--dataset_path", type=str, default="")
parser.add_argument("--dataset_split", type=str, default="test")
parser.add_argument("--dataset_skip", type=int, default=10)
parser.add_argument("--pcd_path", type=str, default="")
parser.add_argument("--mesh_path", type=str, default="")
args = parser.parse_args()
args.use_sam = not args.wo_sam
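
# 2D annotator window and shared 3D-selection state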
cv2.namedWindow("2D Annotator")
vis = None
pnt_w = None
pnt_frame_buffer = []
pnt_frame_mask = None
pnt_mask = None

# Set up the SAM predictor
if args.use_sam:
    import sys
    sys.path.append("./segment-anything")
    from segment_anything import sam_model_registry, SamPredictor
    sam_checkpoint = "segment-anything/checkpoints/sam_vit_h_4b8939.pth"
    # sam_checkpoint = "segment-anything/checkpoints/sam_vit_b_01ec64.pth"
    sam = sam_model_registry['vit_h'](checkpoint=sam_checkpoint)
    sam.to("cuda")
    predictor = SamPredictor(sam)
else:
    predictor = None

obj_size = 4  # default extent, used for the depth-threshold slider
obj_mode = 'pcd'

# Load the 3D geometry to annotate; a point cloud takes priority over a mesh
if args.pcd_path != "":
    pcd_t = o3d.t.io.read_point_cloud(args.pcd_path)
    pcd = pcd_t.to_legacy()
    pnt_w = np.asarray(pcd.points)
    color_ori = np.asarray(pcd.colors).copy()
    bbox = pcd.get_axis_aligned_bounding_box()
    bbox_size = bbox.get_max_bound() - bbox.get_min_bound()
    obj_size = np.max(bbox_size)
    obj = pcd
elif args.mesh_path != "":
    obj_mode = 'mesh'
    mesh_t = o3d.t.io.read_triangle_mesh(args.mesh_path)
    mesh = mesh_t.to_legacy()
    mesh.compute_vertex_normals()
    pnt_w = np.asarray(mesh.vertices)
    color_ori = np.asarray(mesh.vertex_colors).copy()
    if color_ori.shape[0] == 0:
        color_ori = np.ones_like(pnt_w)  # fall back to white if the mesh has no vertex colors
    bbox = mesh.get_axis_aligned_bounding_box()
    bbox_size = bbox.get_max_bound() - bbox.get_min_bound()
    obj_size = np.max(bbox_size)
    obj = mesh

if pnt_w is not None:
    vis = o3d.visualization.Visualizer()
    vis.create_window("3D Visualizer", 800, 800)
    vis.add_geometry(obj)

# Initialize the list of annotated keypoints (x, y, label)
keypoints = []
# Click colors in BGR: label 0 (background) red, label 1 (foreground) green
colors = [(0, 0, 255), (0, 255, 0)]
# RGB highlight colors for the 3D geometry: per-frame and accumulated selections
pnt_sel_color = np.array([0, 0, 255])
pnt_sel_color_global = np.array([255, 0, 0])
pnt_mask_idx = -1
# Initialize the click label mode (1 = foreground)
mode = 1
sel_mode = 1  # 0: single frame, 1: multi frame
depth_frames = []
image_idx = 0
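
# Load the first image: a dataset frame (with camera pose) or a standalone image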
if args.dataset_path != "":
    from nerf_synthetic import NeRFSynthetic
    data = NeRFSynthetic(args.dataset_path, split=args.dataset_split, testskip=args.dataset_skip)
    n_images = len(data)
    original_image_rgb, c2w, image_path = data[image_idx]
    if pnt_w is not None:
        # Project the point cloud into the current view and rasterize a depth map
        uv_cam, pnt_cam, depth = project_pcd(pnt_w, data.K, c2w)
        depth_map, index = get_depth_map(uv_cam, depth, *original_image_rgb.shape[:2], scale=3)
else:
    # Load a single image
    image_path = args.image
    # original_image = cv2.imread(image_path)
    # original_image_rgb = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
    original_image_rgb = imageio.imread(image_path)
    if original_image_rgb.shape[-1] == 4:
        # Composite RGBA onto a white background
        original_image_rgb = original_image_rgb / 255.
        original_image_rgb = original_image_rgb[:, :, :3] * original_image_rgb[:, :, 3:4] + (1 - original_image_rgb[:, :, 3:4])
        original_image_rgb = (original_image_rgb.clip(0, 1) * 255).astype(np.uint8)
original_image = cv2.cvtColor(original_image_rgb, cv2.COLOR_RGB2BGR)
if predictor is not None:
    predictor.set_image(original_image_rgb)
image = original_image.copy()
logits = None
mask = None
print("Image loaded")

# Mouse callback function
def annotate_keypoints(event, x, y, flags, param):
    global keypoints, mode, image, logits, mask
    if event == cv2.EVENT_LBUTTONDOWN:
        # Add the keypoint and mode to the list
        keypoints.append((x, y, mode))
        # print("Keypoint added:", (x, y, mode))
        if predictor is not None:
            # Run SAM with all clicks so far as point prompts
            input_point = np.array([pts[:2] for pts in keypoints])
            input_label = np.array([pts[2] for pts in keypoints])
            masks, scores, logits = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                mask_input=logits,
                multimask_output=False,
            )
            mask = masks[0]
            color_mask = (np.random.random(3) * 255).astype(np.uint8)
            colored_mask = np.repeat(mask[:, :, np.newaxis], 3, axis=2) * color_mask
            image = cv2.addWeighted(original_image, 0.5, colored_mask, 0.5, 0)
        else:
            image = original_image.copy()
        # Draw a circle at each keypoint position with the corresponding color
        for x, y, m in keypoints:
            cv2.circle(image, (x, y), 3, colors[m], -1)
        cv2.imshow("2D Annotator", image)

# Depth threshold percentage for 3D cropping (100 disables depth filtering)
depth_ratio = 100

# Trackbar callback function
def on_trackbar(val):
    global depth_ratio
    depth_ratio = val
    # update_image()
    cv2.imshow("2D Annotator", image)

# Set the mouse callback function on the annotator window
cv2.setMouseCallback("2D Annotator", annotate_keypoints)
# Create a trackbar (slider) to control the depth threshold of the mask
cv2.createTrackbar("Depth Percentage", "2D Annotator", depth_ratio, 100, on_trackbar)
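
# Main loop. Keys: m toggle click label, z undo, s save mask + keypoints,
# n/p next/prev dataset image, r reset, c crop 3D points with the current mask,
# u union / x intersect the crop into the global selection, k toggle selection
# mode, a buffer the current frame mask, e export the masked point cloud, q quit.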
print("Start annotating keypoints")
while True:
cv2.imshow("2D Annotator", image)
key = cv2.waitKey(1) & 0xFF
# Press 'm' to toggle between modes
if key == ord("m"):
mode = (mode + 1) % 2
# Press 'u' to undo the last keypoint
if key == ord("z"):
if keypoints:
# Remove the last keypoint
keypoints.pop()
# Redraw the keypoints
image = original_image.copy()
for x, y, m in keypoints:
cv2.circle(image, (x, y), 3, colors[m], -1)
cv2.imshow("2D Annotator", image)
    # Press 's' to save the mask and keypoints
    if key == ord("s"):
        image_name = os.path.basename(image_path)
        if mask is not None:
            os.makedirs(args.save_path, exist_ok=True)
            mask_path = os.path.join(args.save_path, image_name)
            imageio.imwrite(mask_path + '.png', (mask[..., None] * 255).astype(np.uint8))
            np.savetxt(mask_path + ".txt", keypoints, fmt="%d")
    # Press 'n' to go to the next image
    if key == ord("n") and args.dataset_path != "":
        image_idx = (image_idx + 1) % n_images
        original_image_rgb, c2w, image_path = data[image_idx]
        original_image = cv2.cvtColor(original_image_rgb, cv2.COLOR_RGB2BGR)
        if pnt_w is not None:
            uv_cam, pnt_cam, depth = project_pcd(pnt_w, data.K, c2w)
            depth_map, index = get_depth_map(uv_cam, depth, *original_image_rgb.shape[:2], scale=3)
        if predictor is not None:
            predictor.set_image(original_image_rgb)
        image = original_image.copy()
        keypoints = []
        logits = None
        mask = None
    # Press 'p' to go to the previous image
    if key == ord("p") and args.dataset_path != "":
        image_idx = (image_idx - 1) % n_images
        original_image_rgb, c2w, image_path = data[image_idx]
        original_image = cv2.cvtColor(original_image_rgb, cv2.COLOR_RGB2BGR)
        if pnt_w is not None:
            uv_cam, pnt_cam, depth = project_pcd(pnt_w, data.K, c2w)
            depth_map, index = get_depth_map(uv_cam, depth, *original_image_rgb.shape[:2], scale=3)
        if predictor is not None:
            predictor.set_image(original_image_rgb)
        image = original_image.copy()
        keypoints = []
        logits = None
        mask = None
    # Press 'r' to reset the image
    if key == ord("r"):
        image = original_image.copy()
        keypoints = []
        logits = None
        mask = None
        cv2.imshow("2D Annotator", image)
    # Press 'c' to crop the point cloud with the current 2D mask
    if key == ord("c") and pnt_w is not None and mask is not None:
        if depth_ratio == 100:
            # No depth filtering: select every point that projects inside the mask
            pnt_frame_mask = mask_pcd_2d(uv_cam, mask)[..., None]
        else:
            # Additionally filter by depth; depth_thresh is a fraction of the object size
            depth_thresh = obj_size * depth_ratio / 100
            pnt_frame_mask = mask_pcd_2d(uv_cam, mask, 0.5, depth_map, depth, depth_thresh)[..., None]
        pnt_mask_idx = image_idx
        color = (~pnt_frame_mask) * color_ori + pnt_frame_mask * pnt_sel_color / 255.
        if obj_mode == 'pcd':
            obj.colors = o3d.utility.Vector3dVector(color)
        else:
            obj.vertex_colors = o3d.utility.Vector3dVector(color)
        vis.update_geometry(obj)
    # Press 'u' to union the current frame selection into the global mask
    if key == ord("u") and pnt_w is not None and pnt_frame_mask is not None:
        if pnt_mask is None:
            pnt_mask = pnt_frame_mask.copy()
        else:
            pnt_mask = np.logical_or(pnt_mask, pnt_frame_mask)
        color = (~pnt_mask) * color_ori + pnt_mask * pnt_sel_color_global / 255.
        if obj_mode == 'pcd':
            obj.colors = o3d.utility.Vector3dVector(color)
        else:
            obj.vertex_colors = o3d.utility.Vector3dVector(color)
        vis.update_geometry(obj)
    # Press 'x' to intersect the current frame selection with the global mask
    if key == ord("x") and pnt_w is not None and pnt_frame_mask is not None:
        if pnt_mask is None:
            pnt_mask = pnt_frame_mask.copy()
        else:
            pnt_mask = np.logical_and(pnt_mask, pnt_frame_mask)
        color = (~pnt_mask) * color_ori + pnt_mask * pnt_sel_color_global / 255.
        if obj_mode == 'pcd':
            obj.colors = o3d.utility.Vector3dVector(color)
        else:
            obj.vertex_colors = o3d.utility.Vector3dVector(color)
        vis.update_geometry(obj)
    # Press 'k' to switch the selection mode
    if key == ord("k"):
        sel_mode = (sel_mode + 1) % 2
        print("sel_mode:", 'single frame' if sel_mode == 0 else 'multi frame')
    # Press 'a' to add pnt_frame_mask to the buffer for multi-frame selection
    if key == ord("a") and pnt_w is not None and pnt_frame_mask is not None:
        pnt_frame_buffer.append((pnt_frame_mask, c2w, depth_map, uv_cam))
        print('Add pnt_frame_mask to buffer, buffer size:', len(pnt_frame_buffer))
    # Press 'e' to export the masked point cloud
    if key == ord("e") and pnt_w is not None and pnt_mask is not None:
        if obj_mode == 'pcd':  # 'mesh' is not supported yet
            # Store the selection in a custom 'flags' point attribute (32 marks selected points)
            pcd_t.point.flags = (pnt_mask * 32).astype(np.int32)
            pcd_name = os.path.basename(args.pcd_path)[:-4] + '_mask.ply'
            os.makedirs(args.save_path, exist_ok=True)
            o3d.t.io.write_point_cloud(os.path.join(args.save_path, pcd_name), pcd_t)
            print('Export masked point cloud to', os.path.join(args.save_path, pcd_name))
    # Press 'q' to exit
    if key == ord("q"):
        break
    # Keep the Open3D viewer responsive
    if vis is not None:
        vis.poll_events()
        vis.update_renderer()

# Close all windows
cv2.destroyAllWindows()
if vis is not None:
    vis.destroy_window()
# Print the annotated keypoints
print("Annotated keypoints:", keypoints)