% demo3_reorderingExpt_2D.m
% Demo script for running a dataset re-ordering analysis.
% The dataset is generated under a hidden model in this demo,
% but can easily be replaced by a user dataset in the matching format.
% Related to Fig 8 in Bak & Pillow 2018.
% Copyright (c) 2018 Ji Hyun Bak
%% initialize
clear all; % clear workspace
clc; % clear command window
setpaths; % add toolbox to path
%% Load a dataset to re-order
% === Generate with 2D stimulus & 4-alternatives responses
% Here we generate a dataset under a "hidden" true model;
% to test the program with another dataset, simply replace this chunk
% and provide your own {xdata, ydata, xx0} in matching formats.
% Golden rule: the i-th row in {xdata,ydata} corresponds to the i-th trial.
Ngen = 500; % We will generate 500 trials
[xdata,ydata,~,xx0,yy0] = auxFun_gendat_x2y4(Ngen);
% xdata: the list of stimuli presented
% (each row is a stimulus vector presented in a trial)
% ydata: the list of response categories observed (each row is a trial)
% xdata and ydata should match by the rows (trials).
% xx0: the original stimulus space, defined by the experimenter.
% again each row should be a (distinct) stimulus vector.
% yy0: the response space (list of all available response categories).
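% A rough sketch of the expected shapes, for users supplying their own data
% (the exact response coding depends on the generating function; here we
% assume the categories are those listed in yy0):
%   xdata : [Ngen x 2]  one 2D stimulus vector per row (one row per trial)
%   ydata : [Ngen x 1]  one response category per row (one row per trial)
%   xx0   : [M x 2]     all distinct stimuli available to the experimenter
%   yy0   : the 4 categories of the 4-alternative response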
% === Unpack the dataset ====================================
% Detect stimulus/response dimensions
ydim0 = numel(yy0)-1; % minus 1 for choice probability normalization
gdim0 = numel(gfun(xx0(1,:))); % assume that gfun is known
dims = struct('y',ydim0,'g',gdim0);
% Sort and index the stimuli in dataset
xdata_temp = [xdata;xx0]; % append full stimulus set xx0 to xdata for now,
% to take care of the case where some of the
% stimuli in xx0 were not used in xdata
myuniqrows_flip = @(A) unique(fliplr(A),'rows'); % sort rows with the column order flipped
[xx_flip,~,idata_temp] = myuniqrows_flip(xdata_temp);
xx = fliplr(xx_flip); % so that first column increases first
idata = idata_temp(1:size(xdata,1)); % now cut away the appended rows
% xx: the stimulus space;
% should be equal to xx0 up to sorting.
% idata: indexing of xdata to the stimulus set xx,
% such that xdata = xx(idata,:).
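% Optional sanity check: the indexing above should reproduce the data exactly.
assert(isequal(xdata,xx(idata,:)),'idata should index xdata into xx');
% A tiny (hypothetical) example of the flip-and-unique trick:
%   A = [0 1; 1 0; 0 0];
%   fliplr(unique(fliplr(A),'rows'))  % returns [0 0; 1 0; 0 1],
%   i.e. rows are sorted so that the first column increases fastest.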
% Pack data in the input format
mydat = struct('x',xdata,'y',ydata,'i',idata,'xx',xx);
%% Set options for the algorithm
% === Specify algorithms to use =====================
% Select inference method (either MAP or MCMC)
nsamples = 500; % MCMC chain length
doingMAP = (nsamples==0); % if 0, we make MAP estimate;
doingMCMC = (nsamples>0); % if >0, we run MCMC sampling.
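% To use the MAP (Laplace) estimate instead of MCMC, set nsamples = 0 above;
% the rest of the script branches on doingMAP/doingMCMC automatically.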
% Set prior hyperparameters
hyperprs = ...
struct('wgtmean',0,'wgtsigma',3,... % gaussian prior for weights
'lpsLB',log(0.001),'lpsUB',0, ... % range constraints for lapses
'lpsInit',-5 ... % starting point for lapse parameters
);
% Set whether to include lapses in the model being inferred
withLapse = 0; % 1: use lapse-aware model, 0: ignore lapse
% Unpack parameter dimensions
ydim = dims.y;
gdim = dims.g;
if(withLapse==1)
udim = ydim+1; % lapse parameter length (equal to # choices)
else
udim = 0; % no lapse parameter
end
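% Note: when withLapse==1 the parameter vector is extended by udim auxiliary
% lapse parameters; getLapseProb (used below) presumably maps these to
% per-choice lapse probabilities mixed into the PF (see Bak & Pillow 2018).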
% A human-readable tag describing what we do in this simulation
inftagList = {'Laplace',['MCMC',num2str(nsamples)]};
inftag = inftagList{doingMCMC+1}; % if not MCMC, then MAP
algtag = 'infomax';
withLapseTagList = {'lapse-unaware','lapse-aware'};
lpstag = withLapseTagList{withLapse+1};
runTag = [algtag,'-',inftag];
% === Pack options for sequential experiment =======
optBase = [];
% parameter initialization, bounds and step sizes
K0 = ydim*gdim;
prsInit = [(hyperprs.wgtmean)*ones(K0,1); ...
(hyperprs.lpsInit)*ones(udim,1)]; % initial value for parameters
optBase.prs0 = prsInit(:)';
optBase.prsInit = prsInit(:)'; % duplicate for re-initialization
optBase.prsLB = [-Inf*ones(K0,1); (hyperprs.lpsLB)*ones(udim,1)]'; % lower bound
optBase.prsUB = [Inf*ones(K0,1); (hyperprs.lpsUB)*ones(udim,1)]'; % upper bound
optBase.steps = ones(1,numel(prsInit)); % initial step sizes
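% Layout of the parameter vector (a sketch; the exact ordering is defined in
% paramVec2Struct): the first K0 = ydim*gdim entries are the weights (biases
% and stimulus weights for each of the ydim non-reference choices), followed
% by udim lapse parameters when the lapse-aware model is used.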
% numbers for MCMC samplings
optBase.nsamples = nsamples; % length of chain
optBase.nburn = 500; % # samples for "burn-in"
optBase.nburnInit = 500; % duplicate for re-initialization
optBase.nburnAdd = 50; % burn-in for additional runs
% more options
optBase.prior = hyperprs;
optBase.reportMoreValues = false;
optBase.talkative = 1; % display level
%% All-data fit, to get a "best" estimate
% Because we do not know the true PF, we obtain a best estimate of the PF
% using the entire dataset (which is assumed to be large).
optAllFit = optBase; % reset options for all-data fit
fprintf('\nRunning inference with all %d trials...\n',Ngen);
if(doingMAP)
[probTrue,theta,~,~,~] = fun_BASS_MAP(xx,mydat,dims,optAllFit);
elseif(doingMCMC)
numFullRuns = 5; % iterate a few times to adjust step sizes
for nfr = 1:numFullRuns
[probTrue,theta,~,~,chainLmat,~] = ...
fun_BASS_MCMC(xx,mydat,dims,optAllFit);
% update sampling parameters
optAllFit.prs0 = theta;
optAllFit.steps = chainLmat;
end
end
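% Each pass above restarts the sampler at the latest estimate (prs0 = theta)
% and reuses the fitted chain scales (chainLmat, presumably the proposal step
% sizes), so that subsequent runs mix better.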
% "True" parameters (in this case, approximated by the best estimate
% using all trials in the dataset)
paramBest_AllDataFit = paramVec2Struct(theta,dims);
fprintf('\nBest estimate using all %d trials:\n',Ngen);
fprintf('--------------------------------------\n');
fprintf(' biases: b =%4.1f %4.1f %4.1f\n', paramBest_AllDataFit.b);
fprintf('weights for x1: a[1] =%4.1f %4.1f %4.1f\n', paramBest_AllDataFit.a(1,:));
fprintf('weights for x2: a[2] =%4.1f %4.1f %4.1f\n', paramBest_AllDataFit.a(2,:));
if(withLapse)
lapseBest = sum(getLapseProb(paramBest_AllDataFit.u)); % u is auxiliary lapse parameter
fprintf(' lapse rate: lapse =%5.2f\n\n', lapseBest);
end
% === Plot best PF regimes on stimulus space ===
% Response map color
mymap = [ 0 0.4470 0.7410 ;...
0.8500 0.3250 0.0980 ;...
0.9290 0.6940 0.1250 ;...
0.4660 0.6740 0.1880];
mymap = (1-mymap)*0.5 + mymap;
% Determine which response dominates (has the highest probability)
% given each stimulus value.
[~,mostLikelyResponse] = max(probTrue,[],1); % based on the true choice probability
xx1D = unique(xx(:,1)); % 1D grid, just for plotting (here the stimulus grid is the same along both dimensions)
clf; subplot(2,2,1)
imagesc(xx1D,xx1D,reshape(mostLikelyResponse,numel(xx1D)*[1 1])')
set(gca,'YDir','normal')
colormap(mymap)
hold on
plot(xx(:,1),xx(:,2),'k.') % stimulus space
hold off
axis square
xlabel('stim 1')
ylabel('stim 2')
title('choice regimes, all-data estimate')
% === Plot best PF surfaces ===
subplot(2,2,2)
surf(xx1D,xx1D,reshape(probTrue(1,:),numel(xx1D)*[1 1])','FaceColor',mymap(1,:))
hold on
surf(xx1D,xx1D,reshape(probTrue(2,:),numel(xx1D)*[1 1])','FaceColor',mymap(2,:))
surf(xx1D,xx1D,reshape(probTrue(3,:),numel(xx1D)*[1 1])','FaceColor',mymap(3,:))
surf(xx1D,xx1D,reshape(probTrue(4,:),numel(xx1D)*[1 1])','FaceColor',mymap(4,:))
hold off
axis tight
axis square
zlim([0 1])
xlabel('stim 1')
ylabel('stim 2')
zlabel('P(choice)')
title('PF, all-data estimate')
set(get(gca,'xlabel'),'rotation',25);
set(get(gca,'ylabel'),'rotation',-30);
set(gca,'LineWidth',1.5)
%% Choose some initial stimulus-response observations
% pick initial points
ninit = 10;
iinit = randsample(1:size(xdata,1),ninit,false); % choose ninit random trials without replacement
xinit = xdata(iinit,:);
yinit = ydata(iinit);
% initialize dataset
seqdat = struct('x',xinit,'y',yinit,'i',iinit(:));
% plot initial stimuli
subplot(2,2,1)
hold on
for np = 1:ninit
pointcolor = mymap(yinit(np)+1,:); % color indicates observed response
plot(xinit(np,1),xinit(np,2),'ko','markerfacecolor',pointcolor) % initial stimuli
end
% === Posterior inference with initial data =======
optSeq = optBase; % reset options for sequential experiment
if(doingMAP)
% MAP estimate
[probEst,prmEst,infoCrit,covEntropy,~] = ...
fun_BASS_MAP(xx,seqdat,dims,optSeq);
elseif(doingMCMC)
% MCMC sampling
[probEst,prmEst,infoCrit,covEntropy,chainLmat,~] = ...
fun_BASS_MCMC(xx,seqdat,dims,optSeq);
% adjust next sampling parameters
optSeq.prs0 = prmEst;
optSeq.steps = chainLmat;
optSeq.nburn = optSeq.nburnAdd; % shorter burn-in for non-initial trials
% track sampler properties
chainstd = diag(chainLmat)'; % store diagonal part (std) only
end
% detect choice regimes / decision boundaries
[~,mostLikelyResponse] = max(probEst,[],1); % from estimated PF
% plot estimated PF regimes
subplot(2,2,2)
imagesc(xx1D,xx1D,reshape(mostLikelyResponse,numel(xx1D)*[1 1])')
set(gca,'YDir','normal')
colormap(mymap)
hold on
plot(xx(:,1),xx(:,2),'k.') % stimulus space
hold off
axis square
xlabel('stim 1')
ylabel('stim 2')
title('estimated choice regimes')
% === Select next stimulus using infomax =========
% get indexing right: choose from remaining stimulus-response pairs
remains = find(~ismember(1:size(xdata,1),seqdat.i)); % index for xdata
% infoMax score (1 if max info gain; 0 otherwise)
myscore = double(infoCrit==max(infoCrit(idata(remains)))); % infoCrit matches xx
myscoreR = myscore(idata(remains)); % idata(remains): index for xx
idxR = randsample(numel(remains),1,true,myscoreR); % idxR: index for remains
% pull next trial
idxnext = remains(idxR);
xnext = xdata(idxnext,:);
ynext = ydata(idxnext);
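% How the indexing fits together: infoCrit is defined on the stimulus grid xx,
% remains indexes the not-yet-used rows of xdata, and idata(remains) maps those
% rows back onto xx. The score is 1 for remaining trials whose stimulus attains
% the maximum expected info gain (0 otherwise), and randsample breaks ties
% uniformly at random. Because this is a re-ordering analysis, the "next trial"
% is drawn from the recorded dataset rather than generated afresh.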
% plot expected information gain
subplot(2,2,3)
imagesc(xx1D,xx1D,reshape(infoCrit,numel(xx1D)*[1 1])')
set(gca,'YDir','normal')
axis square
colormap(gca,'gray')
hold on
plot(xnext(1),xnext(2),'r*')
hold off
xlabel('stim 1')
ylabel('stim 2')
title('expected info gain')
%% Adaptively sample and add one stimulus at a time
N = 50; % total # trials in the experiment
% track performance measures
MSE = NaN(N,1); % mean-square error
postEnt = NaN(N,1); % (approximate) posterior covariance entropy
% fill in result with initial data
MSE(ninit) = mean(sum((probTrue'-probEst').^2,2),1);
postEnt(ninit) = covEntropy;
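% MSE compares the current estimate against the all-data ("true") estimate:
% squared errors are summed over the 4 choice probabilities and averaged over
% the stimulus grid. postEnt tracks the (approximate) entropy of the posterior
% covariance returned by the inference routine.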
% sequential experiment
for jj=(ninit+1):N
disp(['Trial #',num2str(jj)]);
% add to dataset
seqdat.x(end+1,:) = xnext;
seqdat.y(end+1) = ynext;
seqdat.i(end+1) = idxnext;
% posterior inference
if(doingMAP)
% MAP estimate
[probEst,prmEst,infoCrit,covEntropy,~] = ...
fun_BASS_MAP(xx,seqdat,dims,optSeq);
elseif(doingMCMC)
% MCMC sampling
[probEst,prmEst,infoCrit,covEntropy,chainLmat,~] = ...
fun_BASS_MCMC(xx,seqdat,dims,optSeq);
% adjust next sampling parameters
optSeq.prs0 = prmEst;
optSeq.steps = chainLmat;
optSeq.nburn = optSeq.nburnAdd; % shorter burn-in for non-initial trials
% track sampler properties
chainstd = diag(chainLmat)'; % store diagonal part (std) only
end
% detect choice regimes / decision boundaries
[~,mostLikelyResponse] = max(probEst,[],1); % from estimated PF
% performance measure
MSE(jj) = mean(sum((probTrue'-probEst').^2,2),1); % mean-square error
postEnt(jj) = covEntropy; % (approximate) covariance entropy
% === Select next stimulus using infomax =========
% get indexing right: choose from remaining stimulus-response pairs
remains = find(~ismember(1:size(xdata,1),seqdat.i)); % index for xdata
% infoMax score (1 if max info gain; 0 otherwise)
myscore = double(infoCrit==max(infoCrit(idata(remains)))); % infoCrit matches xx
myscoreR = myscore(idata(remains)); % idata(remains): index for xx
idxR = randsample(numel(remains),1,true,myscoreR); % idxR: index for remains
% pull next trial
idxnext = remains(idxR);
xnext = xdata(idxnext,:);
ynext = ydata(idxnext);
% === Plot ========================================
% plot selected stimuli
subplot(2,2,1)
hold on
plot(seqdat.x(1:jj-1,1),seqdat.x(1:jj-1,2),'ko','markerfacecolor','k') % previous trials
pointcolor = mymap(seqdat.y(jj)+1,:); % color indicates observed response
plot(seqdat.x(jj,1),seqdat.x(jj,2),'ko','markerfacecolor',pointcolor) % current trial
% plot estimated PF regimes
subplot(2,2,2)
imagesc(xx1D,xx1D,reshape(mostLikelyResponse,numel(xx1D)*[1 1])')
set(gca,'YDir','normal')
colormap(mymap)
hold on
plot(xx(:,1),xx(:,2),'k.') % stimulus space
hold off
axis square
xlabel('stim 1')
ylabel('stim 2')
title('estimated choice regimes')
% plot expected information gain
subplot(2,2,3)
imagesc(xx1D,xx1D,reshape(infoCrit,numel(xx1D)*[1 1])')
set(gca,'YDir','normal')
axis square
colormap(gca,'gray')
hold on
plot(xnext(1),xnext(2),'r*')
hold off
xlabel('stim 1')
ylabel('stim 2')
title('expected info gain')
subplot(4,2,6)
plot(ninit:jj,postEnt(ninit:jj),'k.-')
xlim([ninit N])
ylabel('post. ent.')
subplot(4,2,8)
plot(ninit:jj,MSE(ninit:jj),'k.-')
xlim([ninit N])
ylabel('error')
xlabel('# trials')
drawnow;
% =================================================
end
% Final estimate after the last trial in sequence
paramEst = paramVec2Struct(prmEst,dims); % get a param struct
fprintf('\nEstimated after %d trials:\n',N);
fprintf('--------------------------------------\n');
fprintf(' biases: b =%4.1f %4.1f %4.1f\n', paramEst.b);
fprintf('weights for x1: a[1] =%4.1f %4.1f %4.1f\n', paramEst.a(1,:));
fprintf('weights for x2: a[2] =%4.1f %4.1f %4.1f\n', paramEst.a(2,:));
if(withLapse)
lapseEst = sum(getLapseProb(paramEst.u)); % u is auxiliary lapse parameter
fprintf(' lapse rate: lapse =%5.2f\n\n', lapseEst);
end