-
Notifications
You must be signed in to change notification settings - Fork 36
/
autoTrainRFC.m
277 lines (233 loc) · 9.38 KB
/
autoTrainRFC.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
function autoTrainRFC(varargin)
%AUTOTRAINRFC trains a RandomForestClassifier (RFC)
%based on CCDC change detection results and a land cover reference image.
%
% For every reference pixel (value > 0 in the reference image), the CCDC
% time-series model that spans the whole reference year (starts before
% Jan 1 and ends after Dec 31 of SampleYear) is used as a training sample.
% Its features are the per-band RMSE followed by all model coefficients,
% with the intercept shifted to the temporal midpoint of that model.
% Samples are then drawn per class, proportionally to class frequency,
% clamped to [3%, 40%] of SampleNumber per class, and passed to
% classRF_train.
%
% Specific parameters
% ------------------------
% 'SampleImage'   A reference image in ENVI format. *REQUIRED.
%                 Note: do not include '.hdr'. For example, if your
%                 sample file name is 'your_path/sample.hdr', you
%                 should input 'your_path/sample'.
%
% 'SampleYear'    The year of the reference data (number, or a numeric
%                 string such as '2000'). *REQUIRED.
% 'SampleNumber'  Total number of training samples. Default is 20,000.
% 'NTrees'        Number of trees grown. Default is 500.
% 'CCDCDir'       Directory of CCDC change detection results.
%                 Default is the current folder.
%
% Outputs (written to CCDCDir)
% ------------------------
% modelRF.mat          the trained model (saved with -v7.3)
% CCDC_Train_log.txt   sample counts and version information
%
% For example
% autoTrainRFC('SampleImage','\\Mac\Home\Desktop\CCDCTest\13633','SampleYear' ,2000)
%
% History
% ------------------------
% CCDC 1.3 version - Zhe Zhu, EROS, USGS
%
% Revisions: $ Date: 11/25/2015 $ Copyright: Zhe Zhu
% Version 1.3 Select sample based on best strategy (06/30/2015)
% Version 1.2 Only use undisturbed data (05/11/2015)
% Version 1.1 Use version 7.3 for storing RF model (01/10/2015)
% Version 1.0 Fixed a bug in picking the wrong pixel for training (11/08/2014)
%
% Author: Zhe Zhu (zhe.zhu#ttu.edu)
%         Shi Qiu (shi.qiu#ttu.edu)
% Date: 28. Jun, 2018

%% get parameters from inputs
% version of CCDC (written to the log file)
ccdc_v = 1.3;
% number of time-series model coefficients per band
num_c = 8;

p = inputParser;
p.FunctionName = 'autoTrainRFC';
addParameter(p,'CCDCDir',pwd);
addParameter(p,'SampleImage','');
addParameter(p,'SampleYear','');
addParameter(p,'SampleNumber','');
addParameter(p,'NTrees',500);
parse(p,varargin{:});

ccdc_dir = p.Results.CCDCDir;
dir_out = ccdc_dir;
ntrees = p.Results.NTrees;

name_roi = p.Results.SampleImage;
if isempty(name_roi)
    % required input: nothing sensible can be done without it
    error('autoTrainRFC:missingSampleImage', ...
        'Please input a reference image''s path');
end

trn_year = p.Results.SampleYear;
if ischar(trn_year) && ~isempty(trn_year)
    % accept '2000' as well as 2000
    trn_year = str2double(trn_year);
end
if isempty(trn_year) || ~isnumeric(trn_year) || ~isscalar(trn_year) || isnan(trn_year)
    error('autoTrainRFC:missingSampleYear', ...
        'Please input the year the reference data presents');
end

num_tot = p.Results.SampleNumber;
if isempty(num_tot)
    num_tot = 20000;
    fprintf('Total number of samples are %d (default) \r\n',num_tot);
else
    fprintf('Total number of samples are %d \r\n',num_tot);
end

%% Prepare for the inputs
% get image parameters automatically from the image folders
imf = dir(fullfile(ccdc_dir,'L*'));
[nrows,ncols,nbands] = autoPara(imf);

% ground truth time interval (whole reference year, as datenums)
gt_start = datenum(trn_year,1,1);
gt_end = datenum(trn_year,12,31);

% reference land cover image (all pixels)
im_roi = enviread(name_roi);
% transposed to the [ncols, nrows] layout used for linear indexing below
tim_roi = im_roi';
clear im_roi;

%% Gather features/labels from the CCDC records
tsfitmap_path = fullfile(ccdc_dir,'TSFitMap');
[X, Y] = collectTrainingData(tsfitmap_path, tim_roi, [ncols,nrows], ...
    nbands, num_c, gt_start, gt_end);
if isempty(Y)
    error('autoTrainRFC:noSamples', ...
        'No undisturbed CCDC model covers year %d at any reference pixel.', trn_year);
end

%% selecting pixels for training
[sel_X_trn, sel_Y_trn] = selectTrainingSamples(X, Y, num_tot);

%% log, train and save
writeTrainLog(ccdc_dir, sel_Y_trn, ccdc_v);

modelRF = classRF_train(sel_X_trn,sel_Y_trn,ntrees);
save(fullfile(dir_out,'modelRF'),'modelRF','-v7.3');
fprintf('Finishined Training! The model can be found at %s\r\n',dir_out);
end


function [X, Y] = collectTrainingData(tsfitmap_path, tim_roi, jiDim, nbands, num_c, gt_start, gt_end)
% Build the feature matrix X (N x (num_c+1)*(nbands-1)) and labels Y (N x 1)
% from the per-row CCDC records for every reference pixel whose model spans
% [gt_start, gt_end] entirely.
%
% tsfitmap_path  folder holding record_change<row>.mat files (variable rec_cg)
% tim_roi        reference image, transposed so that linear indices match
%                the 'pos' field of rec_cg (assumed; TODO confirm with CCDC)
% jiDim          [ncols, nrows], used to recover the row of each pixel

% reference pixels: every nonzero value is a class label
idsfind = find(tim_roi > 0);
% second subscript of the [ncols, nrows] layout = image row
[~,i_ids] = ind2sub(jiDim,idsfind);
rec_l = numel(idsfind);

% preallocate for the maximum possible number of samples
X = zeros(rec_l,(num_c+1)*(nbands-1)); % rmse + coefficients of nbands-1 bands
Y = zeros(rec_l,1);                   % reference class labels

i_row = -1;  % row whose record is currently loaded
plusid = 0;  % number of samples collected so far
for i = 1:rec_l
    % idsfind is sorted, so rows are visited in order: load each record once
    if i_ids(i) ~= i_row
        i_row = i_ids(i);
        fprintf('Processing the %dth row ...\n',i_row);
        rec_file = fullfile(tsfitmap_path,['record_change',num2str(i_row)]);
        S = load(rec_file,'rec_cg');
        if ~isfield(S,'rec_cg')
            error('autoTrainRFC:badRecord', ...
                'Variable rec_cg not found in %s.', rec_file);
        end
        rec_cg = S.rec_cg;
        if isempty(rec_cg)
            % row without any stored model: nothing can match below
            t_start = []; t_end = []; coefs = []; rmse = []; pos = [];
        else
            t_start = [rec_cg.t_start];
            t_end = [rec_cg.t_end];
            rmse = [rec_cg.rmse];
            pos = [rec_cg.pos];
            % one num_c x (nbands-1) coefficient matrix per model
            coefs = reshape([rec_cg.coefs],num_c,nbands-1,[]);
        end
    end

    % all models stored for this pixel
    ids_line = find(pos == idsfind(i));
    for j = 1:numel(ids_line)
        id_ref = ids_line(j);
        % keep only models that cover the whole reference year,
        % i.e. the pixel did not change within the training period
        if t_start(id_ref) < gt_start && t_end(id_ref) > gt_end
            plusid = plusid + 1;
            tmp_cft = coefs(:,:,id_ref);
            % overall reflectance at the temporal midpoint of the model
            tmp_cft(1,:) = tmp_cft(1,:) + 0.5*(t_start(id_ref)+t_end(id_ref))*tmp_cft(2,:);
            tmp_rmse = rmse(:,id_ref);
            X(plusid,:) = [tmp_rmse;tmp_cft(:)];
            Y(plusid) = tim_roi(pos(id_ref)); % class category
        end
    end
end

% drop unused preallocated rows and make sure everything is double
X = double(X(1:plusid,:));
Y = double(Y(1:plusid));
end


function [sel_X, sel_Y] = selectTrainingSamples(X, Y, num_tot)
% Randomly draw up to num_tot samples, allocated to classes in proportion
% to their frequency in Y, with each class allocation clamped to
% [round(0.03*num_tot), round(0.4*num_tot)] and limited by availability.

[all_class, class_counts] = countClasses(Y);
prct = class_counts/sum(class_counts);

n_min = round(0.03*num_tot); % minimum # per class
n_max = round(0.4*num_tot);  % maximum # per class

sel_ids = cell(numel(all_class),1);
for i_class = 1:numel(all_class)
    ids = find(Y == all_class(i_class));
    tmp_N = numel(ids);
    tmp_rv = randperm(tmp_N);
    % proportional allocation, clamped to [n_min, n_max]
    adj_num = min(max(ceil(num_tot*prct(i_class)),n_min),n_max);
    % cannot take more samples than the class has
    tot_n = min(tmp_N,adj_num);
    sel_ids{i_class} = ids(tmp_rv(1:tot_n));
end
sel_ids = vertcat(sel_ids{:});

sel_X = X(sel_ids,:);
sel_Y = Y(sel_ids);
end


function writeTrainLog(ccdc_dir, sel_Y_trn, ccdc_v)
% Write CCDC_Train_log.txt into ccdc_dir with sample statistics and the
% version history.

[~, class_number] = countClasses(sel_Y_trn);

log_path = fullfile(ccdc_dir,'CCDC_Train_log.txt');
fileID = fopen(log_path,'w');
if fileID < 0
    error('autoTrainRFC:logFile','Cannot open %s for writing.',log_path);
end
% close the file even if a write below fails
closeLog = onCleanup(@() fclose(fileID));

fprintf(fileID,'Image location = %s\r\n',ccdc_dir);
fprintf(fileID,'Number of sample = %d\r\n',sum(class_number));
fprintf(fileID,'Minimum number of sample per class = %d\r\n',min(class_number));
fprintf(fileID,'Maximum number of sample per class = %d\r\n',max(class_number));
fprintf(fileID,'CCDC Train Version = %.2f\r\n',ccdc_v);
fprintf(fileID,'******************************************************************************************************\r\n');
fprintf(fileID,'Revisions: $ Date: 11/20/2015 $ Copyright: Zhe Zhu\r\n');
fprintf(fileID,'Version 1.3 Select sample based on best strategy (06/30/2015)\r\n');
fprintf(fileID,'Version 1.2 Only use undisturbed data (05/11/2015)\r\n');
fprintf(fileID,'Version 1.1 Use version 7.3 for storing RF model (01/10/2015) \r\n');
fprintf(fileID,'Version 1.0 Fixed a bug in picking the wrong pixel for training (11/08/2014)\r\n');
fprintf(fileID,'******************************************************************************************************\r\n');
clear closeLog; % triggers fclose
end


function [class_values, class_counts] = countClasses(labels)
% Exact per-class counts. Unlike hist(labels, class_values), this stays
% correct when only one class is present (hist would treat a scalar
% second argument as the number of bins).
[class_values,~,idx] = unique(labels(:));
class_counts = accumarray(idx,1);
end