Skip to content

Commit eefab78

Browse files
committed
interactive part
1 parent 26a21ca commit eefab78

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+7310
-0
lines changed
122 KB
123 KB
119 KB
118 KB
119 KB
114 KB
116 KB
117 KB
112 KB
109 KB
109 KB
105 KB
103 KB
105 KB
105 KB
110 KB
111 KB
109 KB
111 KB
112 KB
109 KB
111 KB
112 KB
110 KB
109 KB
112 KB
110 KB
110 KB
112 KB
111 KB
113 KB
115 KB
112 KB
114 KB
116 KB
115 KB
115 KB
115 KB
115 KB
116 KB
117 KB
115 KB
114 KB
115 KB
113 KB
113 KB
115 KB
114 KB
113 KB
114 KB

demo_interactive.m

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
%% Demo for interactive video segmentation
2+
% yuhua chen <[email protected]>
3+
4+
%% configs
5+
addpath caffe/matlab
6+
addpath utils
7+
8+
im_path = 'data/demo_data/horsejump-high';
9+
cache_dir = 'data/cache';
10+
11+
zoom_ratio = 4;
12+
opaque_ratio = 0.5;
13+
14+
net_prototxt = 'models/interactive/deploy.prototxt';
15+
net_weight = 'models/interactive/test.caffemodel';
16+
gpu_id = 1;
17+
18+
%% compute the embedding vectors
19+
caffe.reset_all();
20+
caffe.set_mode_gpu();
21+
caffe.set_device(gpu_id);
22+
net = caffe.Net(net_prototxt,net_weight,'test');
23+
24+
if ~exist(cache_dir,'dir'), mkdir(cache_dir); end % create the cache folder only when missing (avoids the "already exists" warning on re-runs)
25+
im_names = dir(fullfile(im_path,'*.jpg'));
26+
im_names = {im_names.name}; im_names = cellfun(@(x) strrep(x,'.jpg',''),im_names,'UniformOutput',false);
27+
28+
for i_im = 1:numel(im_names)
29+
img = imread(fullfile(im_path,[im_names{i_im} '.jpg']));
30+
img_resize = imresize(img,[478 958]);
31+
img_pp = pre_processing(img_resize); net_input = img_pp;
32+
net.blobs(net.inputs{1}).reshape([size(net_input) 1]);
33+
net.blobs(net.inputs{1}).set_data(net_input);
34+
net.forward_prefilled();
35+
net_output = net.blobs(net.outputs{1}).get_data();
36+
net_output = permute(net_output,[2,1,3]);
37+
38+
save((fullfile(cache_dir,[im_names{i_im} '.mat'])),'net_output');
39+
end
40+
41+
%% load embedding vectors of a subset of images
42+
all_img = cell(4,4);
43+
all_feat = cell(4,4);
44+
for i_im = 1:16
45+
frame_id = round(i_im*numel(im_names)/16);
46+
im_name = im_names{frame_id};
47+
load((fullfile(cache_dir,[im_name '.mat'])),'net_output');
48+
img = imread(fullfile(im_path,[im_name '.jpg']));
49+
img = imresize(img,zoom_ratio*size(net_output(:,:,1)));
50+
51+
all_img{i_im} = img;
52+
all_feat{i_im} = net_output;
53+
end
54+
all_img = all_img'; all_feat = all_feat';
55+
56+
all_img = cell2mat(all_img);
57+
all_feat = cell2mat(all_feat);
58+
all_feat = reshape(all_feat,numel(all_feat(:,:,1)),[]);
59+
60+
%% interaction with users
61+
62+
D_max = zeros(size(all_feat,1),1);
63+
lb_arr = ones(size(all_feat,1),1);
64+
plot_img = uint8((1-opaque_ratio)*all_img);
65+
obj_id = 1;
66+
indx_col = [];
67+
point_lb_col = [];
68+
color_mask = zeros(size(all_img(:,:,1)));
69+
close all
70+
71+
fprintf('\n\n Manual:\n left click: select object\n right click: select background\n z:undo \n r:clear\n numbers:switch object id (for multiple objects)\n Have fun!:)\n\n')
72+
while(1)
73+
imshow(plot_img); hold on;
74+
[x,y,button] = ginput(1);
75+
76+
if(isequal(button,1))
77+
point_lb = obj_id;
78+
elseif(isequal(button,3))
79+
point_lb = 0;
80+
elseif(isequal(button,122)) % z: undo
81+
D_max = D_max_prev;
82+
lb_arr = lb_arr_prev;
83+
color_mask = color_mask_prev;
84+
indx_col = indx_col_prev;
85+
point_lb_col = point_lb_col_prev;
86+
plot_img = uint8((1-opaque_ratio)*all_img + 2*opaque_ratio*uint8(color_mask));
87+
continue;
88+
elseif(isequal(button,114)) % r: clear all clicks
89+
D_max = zeros(size(all_feat,1),1);
90+
lb_arr = ones(size(all_feat,1),1);
91+
plot_img = uint8((1-opaque_ratio)*all_img);
92+
obj_id = 1;
93+
indx_col = [];
94+
point_lb_col = [];
95+
color_mask = zeros(size(all_img(:,:,1)));
96+
continue;
97+
else
98+
obj_id = button - 48;
99+
continue;
100+
end
101+
102+
D_max_prev = D_max;
103+
color_mask_prev = color_mask;
104+
indx_col_prev = indx_col;
105+
point_lb_col_prev = point_lb_col;
106+
lb_arr_prev = lb_arr;
107+
108+
x = round(x/zoom_ratio); y = round(y/zoom_ratio);
109+
indx = size(plot_img,1)/zoom_ratio*(x-1) + y;
110+
111+
indx_col = [indx_col;indx];
112+
point_lb_col = [point_lb_col;point_lb];
113+
114+
D = (1-pdist2(all_feat(:,1:(end-3)),all_feat(indx,1:(end-3)),'cosine'));
115+
[D_max,max_id] = max([D_max,D],[],2);
116+
lb_arr(max_id==2) = point_lb;
117+
118+
mask = reshape(lb_arr,size(all_img,1)/zoom_ratio,size(all_img,2)/zoom_ratio);
119+
120+
mask = imresize(mask,zoom_ratio,'nearest');
121+
color_mask = vis_color(mask,0);
122+
123+
plot_img = uint8((1-opaque_ratio)*all_img + 2*opaque_ratio*uint8(color_mask));
124+
end

0 commit comments

Comments
 (0)