%% Demo for interactive video segmentation
%
% Script outline:
%   1. configs                 - paths, display parameters and network files
%   2. compute the embedding   - run the network on every frame and cache the
%                                output as a .mat file
%   3. load a subset           - tile 16 frames (and their cached outputs)
%                                into one 4x4 mosaic
%   4. interaction with users  - label the mosaic from mouse clicks

%% configs
% MatCaffe bindings (provides the `caffe` package used below).
addpath caffe/matlab
% Project helpers; presumably where pre_processing / vis_color live -- TODO confirm.
addpath utils

% Folder containing the input frames (*.jpg) of one video sequence.
im_path = 'data/demo_data/horsejump-high';
% Folder where the per-frame network outputs are cached as .mat files.
cache_dir = 'data/cache';

% Display upsampling factor: frames are shown at zoom_ratio times the
% spatial size of the network output.
zoom_ratio = 4;
% Blending weight between the (dimmed) frames and the colored label mask.
opaque_ratio = 0.5;

% Network definition and trained weights.
net_prototxt = 'models/interactive/deploy.prototxt';
net_weight = 'models/interactive/test.caffemodel';
% Device index handed to caffe.set_device.
gpu_id = 1;

%% compute the embedding vectors
% Runs the network once per frame found in im_path and stores the resulting
% blob in cache_dir/<frame name>.mat under the variable name 'net_output'.
caffe.reset_all();
caffe.set_mode_gpu();
caffe.set_device(gpu_id);
net = caffe.Net(net_prototxt, net_weight, 'test');

% mkdir emits a warning when the folder already exists, so only create it
% when it is missing.
if ~exist(cache_dir, 'dir')
    mkdir(cache_dir);
end

im_names = dir(fullfile(im_path, '*.jpg'));
if isempty(im_names)
    error('demo:noImages', 'No .jpg frames found in "%s".', im_path);
end
% Keep only the base name of each frame. fileparts removes just the trailing
% extension, whereas strrep would also delete any '.jpg' occurring inside
% the name itself.
[~, im_names] = cellfun(@fileparts, {im_names.name}, 'UniformOutput', false);

for i_im = 1:numel(im_names)
    img = imread(fullfile(im_path, [im_names{i_im} '.jpg']));
    % Fixed network input resolution (rows x cols).
    img_resize = imresize(img, [478 958]);
    img_pp = pre_processing(img_resize); net_input = img_pp;

    % Resize the input blob to this image (batch size 1), fill it and run
    % the forward pass on the pre-filled blob.
    net.blobs(net.inputs{1}).reshape([size(net_input) 1]);
    net.blobs(net.inputs{1}).set_data(net_input);
    net.forward_prefilled();
    net_output = net.blobs(net.outputs{1}).get_data();
    % Swap the first two dimensions; presumably converts the blob from
    % MatCaffe's width-first layout to rows x cols x channels -- TODO confirm.
    net_output = permute(net_output, [2, 1, 3]);

    save(fullfile(cache_dir, [im_names{i_im} '.mat']), 'net_output');
end

%% load embedding vectors of a subset of images
% Picks 16 frames spread evenly over the sequence, loads their cached
% network outputs and tiles both the frames and the outputs into 4x4 mosaics.
if isempty(im_names)
    error('demo:noImages', 'The frame list is empty; nothing to display.');
end

all_img = cell(4, 4);
all_feat = cell(4, 4);
for i_im = 1:16
    % Evenly spaced frame index. For sequences shorter than 8 frames the
    % rounded value can be 0, which is not a valid index, so clamp it to 1.
    frame_id = max(1, round(i_im * numel(im_names) / 16));
    im_name = im_names{frame_id};

    % Load into a struct instead of injecting variables into the workspace.
    cached = load(fullfile(cache_dir, [im_name '.mat']), 'net_output');
    net_output = cached.net_output;

    % Show each frame at zoom_ratio times the spatial size of its output map.
    img = imread(fullfile(im_path, [im_name '.jpg']));
    img = imresize(img, zoom_ratio * size(net_output(:, :, 1)));

    all_img{i_im} = img;
    all_feat{i_im} = net_output;
end
% Linear indexing filled the cells column by column; transposing makes the
% frames run left-to-right, top-to-bottom in the mosaic.
all_img = all_img'; all_feat = all_feat';

all_img = cell2mat(all_img);
all_feat = cell2mat(all_feat);
% Flatten the mosaic of output maps to one row per spatial location
% (column-major order) and one column per channel.
all_feat = reshape(all_feat, numel(all_feat(:, :, 1)), []);

%% interaction with users
% Interactive labelling loop on the 4x4 mosaic.
%   left click  : label with the current object id
%   right click : label as background (0)
%   z           : undo the last click (one level)
%   r           : clear all clicks
%   0-9         : switch the current object id
% Any other key or mouse button, and Enter, is ignored.
%
% State:
%   D_max        - best similarity seen so far per location
%   lb_arr       - label currently assigned per location
%   indx_col     - feature-row index of every click so far
%   point_lb_col - label of every click so far
%   color_mask   - colored rendering of the label map (from vis_color)

% Size of the feature grid underlying the displayed mosaic.
feat_h = size(all_img, 1) / zoom_ratio;
feat_w = size(all_img, 2) / zoom_ratio;

D_max = zeros(size(all_feat, 1), 1);
lb_arr = ones(size(all_feat, 1), 1);
plot_img = uint8((1 - opaque_ratio) * all_img);
obj_id = 1;
indx_col = [];
point_lb_col = [];
color_mask = zeros(size(all_img(:, :, 1)));

% Seed the undo snapshot with the initial state so that pressing 'z' before
% the first click does not hit undefined variables.
D_max_prev = D_max;
lb_arr_prev = lb_arr;
color_mask_prev = color_mask;
indx_col_prev = indx_col;
point_lb_col_prev = point_lb_col;

close all

fprintf(' \n\n Manual:\n left click: select object\n right click: select background\n z:undo \n r:clear\n numbers:switch object id (for multiple objects)\n Have fun!:)\n\n ' )
while (1)
    imshow(plot_img); hold on;
    [x, y, button] = ginput(1);

    % Enter returns empty outputs: nothing to do.
    if isempty(button)
        continue;
    end

    if isequal(button, 1)
        point_lb = obj_id;
    elseif isequal(button, 3)
        point_lb = 0;
    elseif isequal(button, 122) % z: undo
        D_max = D_max_prev;
        lb_arr = lb_arr_prev;
        color_mask = color_mask_prev;
        indx_col = indx_col_prev;
        point_lb_col = point_lb_col_prev;
        plot_img = uint8((1 - opaque_ratio) * all_img + 2 * opaque_ratio * uint8(color_mask));
        continue;
    elseif isequal(button, 114) % r: clear all clicks
        D_max = zeros(size(all_feat, 1), 1);
        lb_arr = ones(size(all_feat, 1), 1);
        plot_img = uint8((1 - opaque_ratio) * all_img);
        obj_id = 1;
        indx_col = [];
        point_lb_col = [];
        color_mask = zeros(size(all_img(:, :, 1)));
        continue;
    elseif button >= 48 && button <= 57 % '0'..'9': switch object id
        obj_id = button - 48;
        continue;
    else
        % Unrecognised key or mouse button: ignore it instead of turning
        % its key code into a bogus object id.
        continue;
    end

    % Snapshot the state so this click can be undone.
    D_max_prev = D_max;
    color_mask_prev = color_mask;
    indx_col_prev = indx_col;
    point_lb_col_prev = point_lb_col;
    lb_arr_prev = lb_arr;

    % Map the clicked display pixel to its feature-grid cell. Pixel k spans
    % [k-0.5, k+0.5] in axes coordinates and each cell covers zoom_ratio
    % pixels. Clamp so clicks on or outside the border stay in range.
    x = min(max(ceil((x - 0.5) / zoom_ratio), 1), feat_w);
    y = min(max(ceil((y - 0.5) / zoom_ratio), 1), feat_h);
    % Row of all_feat for that cell (column-major, matching the reshape
    % that produced all_feat).
    indx = sub2ind([feat_h, feat_w], y, x);

    indx_col = [indx_col; indx];
    point_lb_col = [point_lb_col; point_lb];

    % Cosine similarity between every location and the clicked one. The
    % last three feature columns are left out of the comparison
    % (presumably they are not part of the appearance embedding -- TODO confirm).
    D = (1 - pdist2(all_feat(:, 1:(end - 3)), all_feat(indx, 1:(end - 3)), 'cosine'));
    % A location takes the new label only where this click is strictly more
    % similar than every earlier click (ties keep the previous label).
    [D_max, max_id] = max([D_max, D], [], 2);
    lb_arr(max_id == 2) = point_lb;

    mask = reshape(lb_arr, feat_h, feat_w);

    mask = imresize(mask, zoom_ratio, 'nearest');
    color_mask = vis_color(mask, 0);

    plot_img = uint8((1 - opaque_ratio) * all_img + 2 * opaque_ratio * uint8(color_mask));
end
0 commit comments