COIN.m (forked from jamesheald/COIN)
classdef COIN < matlab.mixin.Copyable
% COIN v1.0.0
% Author: James Heald
% PROPERTIES description
% core parameters
% sigma_process_noise standard deviation of process noise
% sigma_sensory_noise standard deviation of sensory noise
% sigma_motor_noise standard deviation of motor noise
% prior_mean_retention prior mean of retention
% prior_precision_retention prior precision (inverse variance) of retention
% prior_precision_drift prior precision (inverse variance) of drift
% gamma_context gamma hyperparameter of the Chinese restaurant franchise for the context transitions
% alpha_context alpha hyperparameter of the Chinese restaurant franchise for the context transitions
% rho_context rho (normalised self-transition) hyperparameter of the Chinese restaurant franchise for the context transitions
% parameters if cues are present
% gamma_cue gamma hyperparameter of the Chinese restaurant franchise for the cue emissions
% alpha_cue alpha hyperparameter of the Chinese restaurant franchise for the cue emissions
% parameters if inferring bias
% infer_bias infer the measurement bias (true) or not (false)
% prior_precision_bias precision (inverse variance) of prior of measurement bias
% paradigm
% perturbations vector of perturbations (use NaN on channel trials)
% cues vector of sensory cues (encode cues as consecutive integers starting from 1)
% stationary_trials trials on which to set the predicted probabilities to the stationary probabilities (e.g. following a working memory task)
% runs
% runs number of runs, each conditioned on a different state feedback sequence
% parallel processing of runs
% max_cores maximum number of CPU cores available (0 implements serial processing of runs)
% model implementation
% particles number of particles
% max_contexts maximum number of contexts that can be instantiated
% measured adaptation data
% adaptation vector of adaptation data (use NaN on trials where adaptation was not measured)
% store
% store variables to store in memory
% plot flags
% plot_state_given_context plot state | context distribution ('predicted state distribution for each context')
% plot_predicted_probabilities plot predicted probabilities
% plot_responsibilities plot responsibilities
% plot_stationary_probabilities plot stationary probabilities
% plot_retention_given_context plot retention | context distribution
% plot_drift_given_context plot drift | context distribution
% plot_bias_given_context plot bias | context distribution
% plot_global_transition_probabilities plot global transition probabilities
% plot_local_transition_probabilities plot local transition probabilities
% plot_global_cue_probabilities plot global cue probabilities
% plot_local_cue_probabilities plot local cue probabilities
% plot_state plot state ('overall predicted state distribution')
% plot_average_state plot average state (mean of 'overall predicted state distribution')
% plot_bias plot bias distribution (average bias distribution across contexts)
% plot_average_bias plot average bias (mean of average bias distribution across contexts)
% plot_state_feedback plot predicted state feedback distribution (average state feedback distribution across contexts)
% plot_explicit_component plot explicit component of learning
% plot_implicit_component plot implicit component of learning
% plot_Kalman_gain_given_cstar1 plot Kalman gain | context with highest responsibility on current trial (cstar1)
% plot_predicted_probability_cstar1 plot predicted probability of context with highest responsibility on current trial (cstar1)
% plot_state_given_cstar1 plot state | context with highest responsibility on current trial (cstar1)
% plot_Kalman_gain_given_cstar2 plot Kalman gain | context with highest predicted probability on next trial (cstar2)
% plot_state_given_cstar2 plot state | context with highest predicted probability on next trial (cstar2)
% plot_predicted_probability_cstar3 plot predicted probability of context with highest predicted probability on current trial (cstar3)
% plot_state_given_cstar3 plot state | context with highest predicted probability on current trial (cstar3)
% plot inputs
% retention_values specify values at which to evaluate p(retention) if plot_retention_given_context == true
% drift_values specify values at which to evaluate p(drift) if plot_drift_given_context == true
% state_values specify values at which to evaluate p(state) if plot_state_given_context == true or plot_state == true
% bias_values specify values at which to evaluate p(bias) if plot_bias_given_context == true or plot_bias == true
% state_feedback_values specify values at which to evaluate p(state feedback) if plot_state_feedback == true
% miscellaneous user data
% user_data any data the user would like to associate with an object of the class
%
% VARIABLES description
% average_state average predicted state (average across contexts and particles)
% bias bias of each context (sample)
% bias_distribution bias distribution (discretised)
% bias_mean mean of the posterior of the bias for each context
% bias_ss_1 sufficient statistic #1 for the bias parameter of each context
% bias_ss_2 sufficient statistic #2 for the bias parameter of each context
% bias_var variance of the posterior of the bias for each context
% C number of instantiated contexts
% context context (sample)
% drift state drift of each context (sample)
% dynamics_covar covariance of the posterior of the retention and drift of each context
% dynamics_mean mean of the posterior of the retention and drift of each context
% dynamics_ss_1 sufficient statistic #1 for the retention and drift parameters of each context
% dynamics_ss_2 sufficient statistic #2 for the retention and drift parameters of each context
% explicit explicit component of learning
% global_cue_posterior parameters of the posterior of the global cue distribution
% global_cue_probabilities global cue distribution (sample)
% global_transition_posterior parameters of the posterior of the global transition distribution
% global_transition_probabilities global transition distribution (sample)
% i_observed indices of observed states
% i_resampled indices of resampled particles
% implicit implicit component of learning
% Kalman_gains Kalman gain for each context
% local_cue_matrix expected local cue probability matrix
% local_transition_matrix expected local context transition probability matrix
% m_context number of tables in restaurant i serving dish j (Chinese restaurant franchise for the context transitions)
% m_cue number of tables in restaurant i serving dish j (Chinese restaurant franchise for the cue emissions)
% motor_noise motor noise
% motor_output average predicted state feedback (average across contexts and particles), a.k.a. the motor output
% n_context local context transition counts
% n_cue local cue emission counts
% predicted_probabilities predicted context probabilities (conditioned on the cue)
% prediction_error state feedback prediction error for each context
% previous_context context sampled on the previous trial
% previous_state_filtered_mean mean of the filtered state distribution for each context on the previous trial
% previous_state_filtered_var variance of the filtered state distribution for each context on the previous trial
% previous_x_dynamics samples of the states on the previous trial (to update the sufficient statistics for the retention and drift parameters of each context)
% prior_probabilities prior context probabilities (not conditioned on the cue)
% probability_cue probability of the observed cue for each context
% probability_state_feedback probability of the observed state feedback for each context
% Q number of cues observed
% responsibilities context responsibilities (conditioned on the cue and the state feedback)
% retention state retention factor of each context (sample)
% sensory_noise sensory noise
% state_distribution predicted state distribution (discretised)
% state_feedback_distribution predicted state feedback distribution (discretised)
% state_feedback_mean mean of the predicted state feedback distribution for each context
% state_feedback_var variance of the predicted state feedback distribution for each context
% state_filtered_mean mean of the filtered state distribution for each context
% state_filtered_var variance of the filtered state distribution for each context
% state_mean mean of the predicted state distribution for each context
% state_var variance of the predicted state distribution for each context
% stationary_probabilities stationary context probabilities
% x_bias samples of the states on the current trial (to update the sufficient statistics for the bias parameter of each context)
% x_dynamics samples of the states on the current trial (to update the sufficient statistics for the retention and drift parameters of each context)
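%
% EXAMPLE USAGE (a minimal sketch; the perturbation schedule and number of
% runs below are illustrative, not prescriptive)
%   obj = COIN;                                              % model object with default parameters
%   obj.perturbations = [zeros(1,50) ones(1,100) NaN(1,25)]; % null trials, perturbation trials, channel trials
%   obj.runs = 5;                                            % condition each run on a different state feedback sequence
%   S = obj.simulate_COIN;                                   % simulate the model
%   plot(S.runs{1}.motor_output)                             % motor output of the first run (stored by default)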
properties
% core parameters - values taken from Heald et al. (2020) Table S1 (A)
sigma_process_noise = 0.0089
sigma_sensory_noise = 0.03
sigma_motor_noise = 0.0182
prior_mean_retention = 0.9425
prior_precision_retention = 837.1^2
prior_precision_drift = 1.2227e+3^2
gamma_context = 0.1
alpha_context = 8.955
rho_context = 0.2501
% parameters if cues are present
gamma_cue = 0.1
alpha_cue = 25
% parameters if inferring a bias
infer_bias = false
prior_precision_bias = 70^2
% paradigm
perturbations
cues
stationary_trials
% number of runs
runs = 1
% parallel processing
max_cores = 0
% model implementation
particles = 100
max_contexts = 10
% measured adaptation data
adaptation
% store
store = {'state_feedback','motor_output'}
% plot flags
plot_state_given_context = false
plot_predicted_probabilities = false
plot_responsibilities = false
plot_stationary_probabilities = false
plot_retention_given_context = false
plot_drift_given_context = false
plot_bias_given_context = false
plot_global_transition_probabilities = false
plot_local_transition_probabilities = false
plot_global_cue_probabilities = false
plot_local_cue_probabilities = false
plot_state = false
plot_average_state = false
plot_bias = false
plot_average_bias = false
plot_state_feedback = false
plot_explicit_component = false
plot_implicit_component = false
plot_Kalman_gain_given_cstar1 = false
plot_predicted_probability_cstar1 = false
plot_state_given_cstar1 = false
plot_Kalman_gain_given_cstar2 = false
plot_state_given_cstar2 = false
plot_predicted_probability_cstar3 = false
plot_state_given_cstar3 = false
% plot inputs
retention_values = linspace(0.8,1,500);
drift_values = linspace(-0.1,0.1,500);
state_values = linspace(-1.5,1.5,500);
bias_values = linspace(-1.5,1.5,500);
state_feedback_values = linspace(-1.5,1.5,500);
% user data
user_data
end
methods
function S = simulate_COIN(obj)
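% simulate the COIN model; returns a structure S with fields:
%   runs        the variables stored on each run (see the store property)
%   weights     the weight of each run (uniform unless conditioned on adaptation data)
%   properties  the object used to run the simulation
%   plots       data for any plots requested via the plot flags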
if ~isempty(obj.cues)
obj.check_cue_labels;
end
% set the store property based on the plots requested
obj_store = copy(obj);
obj_store = set_store_property_for_plots(obj_store);
% number of trials
T = numel(obj.perturbations);
% preallocate memory
tmp = cell(1,obj.runs);
if isempty(obj.adaptation)
trials = 1:T;
% perform runs
fprintf('Simulating the COIN model.\n')
parfor (run = 1:obj.runs,obj.max_cores)
tmp{run} = obj_store.main_loop(trials).stored;
end
% assign equal weights to all runs
w = ones(1,obj.runs)/obj.runs;
else
if numel(obj.adaptation) ~= numel(obj.perturbations)
error('Property ''adaptation'' should be a vector with one element per trial (use NaN on trials where adaptation was not measured).')
end
% perform runs
% resample runs whenever the effective sample size falls below threshold
% preallocate memory
D_in = cell(1,obj.runs);
D_out = cell(1,obj.runs);
% initialise weights to be uniform
w = ones(1,obj.runs)/obj.runs;
% effective sample size threshold for resampling
ESS_threshold = 0.5*obj.runs;
% trials on which adaptation was measured
adaptation_trials = find(~isnan(obj.adaptation));
% simulate the trials between successive trials on which adaptation was measured
for i = 1:numel(adaptation_trials)
if i == 1
trials = 1:adaptation_trials(i);
fprintf('Simulating the COIN model from trial 1 to trial %d.\n',adaptation_trials(i))
else
trials = adaptation_trials(i-1)+1:adaptation_trials(i);
fprintf('Simulating the COIN model from trial %d to trial %d.\n',adaptation_trials(i-1)+1,adaptation_trials(i))
end
parfor (run = 1:obj.runs,obj.max_cores)
if i == 1
D_out{run} = obj_store.main_loop(trials);
else
D_out{run} = obj_store.main_loop(trials,D_in{run});
end
end
% calculate the log likelihood
log_likelihood = zeros(1,obj.runs);
for run = 1:obj.runs
model_error = D_out{run}.stored.motor_output(adaptation_trials(i)) - obj.adaptation(adaptation_trials(i));
log_likelihood(run) = -(log(2*pi*obj.sigma_motor_noise^2) + (model_error/obj.sigma_motor_noise).^2)/2;
end
% update the weights and normalise
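% (computed in log space with log-sum-exp to avoid numerical underflow of small likelihoods)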
l_w = log_likelihood + log(w);
l_w = l_w - obj.log_sum_exp(l_w');
w = exp(l_w);
% calculate the effective sample size
ESS = 1/(sum(w.^2));
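% ESS ranges from 1 (a single run carries all the weight) to obj.runs (uniform weights);
% e.g. for 4 runs, w = [0.97 0.01 0.01 0.01] gives ESS = 1/0.9412, approximately 1.06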
% if the effective sample size falls below ESS_threshold, resample
if ESS < ESS_threshold
fprintf('Effective sample size = %.1f %s resampling runs.\n',ESS,char(8212))
i_resampled = obj.systematic_resampling(w);
for run = 1:obj.runs
D_in{run} = D_out{i_resampled(run)};
end
w = ones(1,obj.runs)/obj.runs;
else
fprintf('Effective sample size = %.1f.\n',ESS)
D_in = D_out;
end
end
if adaptation_trials(end) == T
for run = 1:obj.runs
tmp{run} = D_in{run}.stored;
end
elseif adaptation_trials(end) < T
% simulate to the last trial
fprintf('Simulating the COIN model from trial %d to trial %d.\n',adaptation_trials(end)+1,T)
trials = adaptation_trials(end)+1:T;
parfor (run = 1:obj.runs,obj.max_cores)
tmp{run} = obj_store.main_loop(trials,D_in{run}).stored;
end
end
end
% preallocate memory
S.runs = cell(1,obj.runs);
% assign data to S
for run = 1:obj.runs
S.runs{run} = tmp{run};
end
S.weights = w;
S.properties = obj;
% generate plots
props = properties(obj);
for i = find(contains(props','plot'))
if obj.(props{i})
S.plots = plot_COIN(obj,S);
break
end
end
% delete the raw variables that were stored to generate the plots
field_names = fieldnames(S.runs{1});
for i = 1:length(field_names)
if all(~strcmp(field_names{i},obj.store))
for run = 1:obj.runs
S.runs{run} = rmfield(S.runs{run},field_names{i});
end
end
end
end
function obj = set_store_property_for_plots(obj)
% specify variables that need to be stored for plots
tmp = {};
if obj.plot_state_given_context
tmp = cat(2,tmp,{'state_mean','state_var'});
end
if obj.plot_predicted_probabilities
tmp = cat(2,tmp,'predicted_probabilities');
end
if obj.plot_responsibilities
tmp = cat(2,tmp,'responsibilities');
end
if obj.plot_stationary_probabilities
tmp = cat(2,tmp,'stationary_probabilities');
end
if obj.plot_retention_given_context
tmp = cat(2,tmp,{'dynamics_mean','dynamics_covar'});
end
if obj.plot_drift_given_context
tmp = cat(2,tmp,{'dynamics_mean','dynamics_covar'});
end
if obj.plot_bias_given_context
if obj.infer_bias
tmp = cat(2,tmp,{'bias_mean','bias_var'});
else
error('You must infer the measurement bias parameter to use plot_bias_given_context. Set property ''infer_bias'' to true.')
end
end
if obj.plot_global_transition_probabilities
tmp = cat(2,tmp,'global_transition_posterior');
end
if obj.plot_local_transition_probabilities
tmp = cat(2,tmp,'local_transition_matrix');
end
if obj.plot_local_cue_probabilities
if isempty(obj.cues)
error('An experiment must have sensory cues to use plot_local_cue_probabilities.')
else
tmp = cat(2,tmp,'local_cue_matrix');
end
end
if obj.plot_global_cue_probabilities
if isempty(obj.cues)
error('An experiment must have sensory cues to use plot_global_cue_probabilities.')
else
tmp = cat(2,tmp,'global_cue_posterior');
end
end
if obj.plot_state
tmp = cat(2,tmp,'state_distribution','average_state');
end
if obj.plot_average_state
tmp = cat(2,tmp,'average_state');
end
if obj.plot_bias
if obj.infer_bias
tmp = cat(2,tmp,'bias_distribution','implicit');
else
error('You must infer the measurement bias parameter to use plot_bias. Set property ''infer_bias'' to true.')
end
end
if obj.plot_average_bias
tmp = cat(2,tmp,'implicit');
end
if obj.plot_state_feedback
tmp = cat(2,tmp,'state_feedback_distribution');
end
if obj.plot_explicit_component
tmp = cat(2,tmp,'explicit');
end
if obj.plot_implicit_component
tmp = cat(2,tmp,'implicit');
end
if obj.plot_Kalman_gain_given_cstar1
tmp = cat(2,tmp,'Kalman_gain_given_cstar1');
end
if obj.plot_predicted_probability_cstar1
tmp = cat(2,tmp,'predicted_probability_cstar1');
end
if obj.plot_state_given_cstar1
tmp = cat(2,tmp,'state_given_cstar1');
end
if obj.plot_Kalman_gain_given_cstar2
tmp = cat(2,tmp,'Kalman_gain_given_cstar2');
end
if obj.plot_state_given_cstar2
tmp = cat(2,tmp,'state_given_cstar2');
end
if obj.plot_predicted_probability_cstar3
tmp = cat(2,tmp,'predicted_probability_cstar3');
end
if obj.plot_state_given_cstar3
tmp = cat(2,tmp,'state_given_cstar3');
end
if ~isempty(tmp)
tmp = cat(2,tmp,{'context','i_resampled'});
end
% add strings in tmp to the store property of obj
for i = 1:numel(tmp)
if ~any(strcmp(obj.store,tmp{i}))
obj.store{end+1} = tmp{i};
end
end
end
function objective = objective_COIN(obj)
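% negative log of the likelihood of the measured adaptation data, averaged
% across runs - a scalar objective suitable for parameter optimisation.
% example sketch (the helper and the choice of fminsearch are illustrative
% assumptions, not a prescribed fitting pipeline):
%   f = @(theta) set_params_and_evaluate(obj,theta); % hypothetical helper: writes theta into obj, returns obj.objective_COIN
%   theta_fit = fminsearch(f,theta_init);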
P = numel(obj); % number of participants
n = sum(~isnan(obj(1).adaptation)); % number of adaptation measurements per participant
adaptation_trials = zeros(n,P);
data = zeros(n,P);
for p = 1:P
if numel(obj(p).adaptation) ~= numel(obj(p).perturbations)
error('Property ''adaptation'' should be a vector with one element per trial (use NaN on trials where adaptation was not measured).')
end
if isrow(obj(p).adaptation)
obj(p).adaptation = obj(p).adaptation';
end
% trials on which adaptation was measured
adaptation_trials(:,p) = find(~isnan(obj(p).adaptation));
% measured adaptation
data(:,p) = obj(p).adaptation(adaptation_trials(:,p));
end
log_likelihood = zeros(obj(1).runs,1);
parfor (run = 1:obj(1).runs,obj(1).max_cores)
model = zeros(n,P);
for p = 1:P
% number of trials
T = numel(obj(p).perturbations);
trials = 1:T;
% model adaptation
model(:,p) = obj(p).main_loop(trials).stored.motor_output(adaptation_trials(:,p));
end
% error between average model adaptation and average measured adaptation
model_error = mean(model-data,2);
% log likelihood (probability of data given parameters)
log_likelihood(run) = sum(-(log(2*pi*obj(1).sigma_motor_noise^2/P) + model_error.^2/(obj(1).sigma_motor_noise.^2/P))/2); % variance scaled by the number of participants
end
% negative of the log of the average likelihood across runs
objective = -(log(1/obj(1).runs) + obj(1).log_sum_exp(log_likelihood));
end
function D = main_loop(obj,trials,varargin)
if trials(1) == 1
D = obj.initialise_COIN; % initialise the model
else
D = varargin{1};
end
for trial = trials
D.t = trial; % set the current trial number
D = obj.predict_context(D); % predict the context
D = obj.predict_states(D); % predict the states
D = obj.predict_state_feedback(D); % predict the state feedback
D = obj.resample_particles(D); % resample particles
D = obj.sample_context(D); % sample the context
D = obj.update_belief_about_states(D); % update the belief about the states given state feedback
D = obj.sample_states(D); % sample the states
D = obj.update_sufficient_statistics_for_parameters(D); % update the sufficient statistics for the parameters
D = obj.sample_parameters(D); % sample the parameters
D = obj.store_variables(D); % store variables for analysis if desired
end
end
function D = initialise_COIN(obj)
% number of trials
D.T = numel(obj.perturbations);
% is state feedback observed or not
D.feedback_observed = ones(1,D.T);
D.feedback_observed(isnan(obj.perturbations)) = 0;
% self-transition bias
D.kappa = obj.alpha_context*obj.rho_context/(1-obj.rho_context);
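% (rho_context is the normalised self-transition bias, rho = kappa/(alpha_context + kappa);
% this expression inverts that relationship to recover kappa)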
% observation noise standard deviation
D.sigma_observation_noise = sqrt(obj.sigma_sensory_noise^2 + obj.sigma_motor_noise^2);
% matrix of context-dependent observation vectors
D.H = eye(obj.max_contexts+1);
% current trial
D.t = 0;
% number of contexts instantiated so far
D.C = zeros(1,obj.particles);
% context transition counts
D.n_context = zeros(obj.max_contexts+1,obj.max_contexts+1,obj.particles);
% sampled context
D.context = ones(1,obj.particles); % treat trial 1 as a (context 1) self-transition
% do cues exist?
if isempty(obj.cues)
D.cuesExist = 0;
else
D.cuesExist = 1;
% number of contextual cues observed so far
D.Q = 0;
% cue emission counts
D.n_cue = zeros(obj.max_contexts+1,max(obj.cues)+1,obj.particles);
end
% sufficient statistics for the parameters of the state dynamics
% function
D.dynamics_ss_1 = zeros(obj.max_contexts+1,obj.particles,2);
D.dynamics_ss_2 = zeros(obj.max_contexts+1,obj.particles,2,2);
% sufficient statistics for the parameters of the observation function
D.bias_ss_1 = zeros(obj.max_contexts+1,obj.particles);
D.bias_ss_2 = zeros(obj.max_contexts+1,obj.particles);
% sample parameters from the prior
D = sample_parameters(obj,D);
% mean and variance of state (stationary distribution)
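% for state dynamics x_t = a*x_{t-1} + d + process noise, the stationary
% mean is d/(1-a) and the stationary variance is sigma_process^2/(1-a^2)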
D.state_filtered_mean = D.drift./(1-D.retention);
D.state_filtered_var = obj.sigma_process_noise^2./(1-D.retention.^2);
end
function D = predict_context(obj,D)
if ismember(D.t,obj.stationary_trials)
% if some event (e.g. a working memory task) causes the context
% probabilities to be erased, set them to their stationary values
for particle = 1:obj.particles
C = sum(D.local_transition_matrix(:,1,particle)>0);
T = D.local_transition_matrix(1:C,1:C,particle);
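% the stationary probabilities satisfy pi = pi*T, i.e. pi is the left
% eigenvector of T with eigenvalue 1 (computed by the stationary_distribution helper)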
D.prior_probabilities(1:C,particle) = obj.stationary_distribution(T);
end
else
i = sub2ind(size(D.local_transition_matrix),repmat(D.context,[obj.max_contexts+1,1]),repmat(1:obj.max_contexts+1,[obj.particles,1])',repmat(1:obj.particles,[obj.max_contexts+1,1]));
D.prior_probabilities = D.local_transition_matrix(i);
end
if D.cuesExist
i = sub2ind(size(D.local_cue_matrix),repmat(1:obj.max_contexts+1,[obj.particles,1])',repmat(obj.cues(D.t),[obj.max_contexts+1,obj.particles]),repmat(1:obj.particles,[obj.max_contexts+1,1]));
D.probability_cue = D.local_cue_matrix(i);
D.predicted_probabilities = D.prior_probabilities.*D.probability_cue;
D.predicted_probabilities = D.predicted_probabilities./sum(D.predicted_probabilities,1);
else
D.predicted_probabilities = D.prior_probabilities;
end
if any(strcmp(obj.store,'Kalman_gain_given_cstar2'))
if D.t > 1
[~,i] = max(D.predicted_probabilities,[],1);
i = sub2ind(size(D.Kalman_gains),i,1:obj.particles);
D.Kalman_gain_given_cstar2 = mean(D.Kalman_gains(i));
end
end
if any(strcmp(obj.store,'state_given_cstar2'))
if D.t > 1
[~,i] = max(D.predicted_probabilities,[],1);
i = sub2ind(size(D.state_mean),i,1:obj.particles);
D.state_given_cstar2 = mean(D.state_mean(i));
end
end
if any(strcmp(obj.store,'predicted_probability_cstar3'))
D.predicted_probability_cstar3 = mean(max(D.predicted_probabilities,[],1));
end
end
function D = predict_states(obj,D)
% propagate states
D.state_mean = D.retention.*D.state_filtered_mean + D.drift;
D.state_var = D.retention.^2.*D.state_filtered_var + obj.sigma_process_noise^2;
% index of novel states
i_new_x = sub2ind([obj.max_contexts+1,obj.particles],D.C+1,1:obj.particles);
% novel states are distributed according to the stationary distribution
D.state_mean(i_new_x) = D.drift(i_new_x)./(1-D.retention(i_new_x));
D.state_var(i_new_x) = obj.sigma_process_noise^2./(1-D.retention(i_new_x).^2);
% predict state (marginalise over contexts and particles)
% mean of distribution
D.average_state = sum(D.predicted_probabilities.*D.state_mean,'all')/obj.particles;
if any(strcmp(obj.store,'explicit'))
if D.t == 1
D.explicit = mean(D.state_mean(1,:));
else
[~,i] = max(D.responsibilities,[],1);
i = sub2ind(size(D.state_mean),i,1:obj.particles);
D.explicit = mean(D.state_mean(i));
end
end
if any(strcmp(obj.store,'state_given_cstar3'))
[~,i] = max(D.predicted_probabilities,[],1);
i = sub2ind(size(D.state_mean),i,1:obj.particles);
D.state_given_cstar3 = mean(D.state_mean(i));
end
end
function D = predict_state_feedback(obj,D)
% predict state feedback for each context
D.state_feedback_mean = D.state_mean + D.bias;
% variance of state feedback prediction for each context
D.state_feedback_var = D.state_var + D.sigma_observation_noise^2;
D = obj.compute_marginal_distribution(D);
% predict state feedback (marginalise over contexts and particles)
% mean of distribution
D.motor_output = sum(D.predicted_probabilities.*D.state_feedback_mean,'all')/obj.particles;
if any(strcmp(obj.store,'implicit'))
D.implicit = D.motor_output - D.average_state;
end
% sensory and motor noise
D.sensory_noise = obj.sigma_sensory_noise*randn;
D.motor_noise = obj.sigma_motor_noise*randn;
% state feedback
D.state_feedback = obj.perturbations(D.t) + D.sensory_noise + D.motor_noise;
% state feedback prediction error
D.prediction_error = D.state_feedback - D.state_feedback_mean;
end
function D = resample_particles(obj,D)
D.probability_state_feedback = normpdf(D.state_feedback,D.state_feedback_mean,sqrt(D.state_feedback_var)); % p(y_t|c_t)
if D.feedback_observed(D.t)
if D.cuesExist
p_c = log(D.prior_probabilities) + log(D.probability_cue) + log(D.probability_state_feedback); % log p(y_t,q_t,c_t)
else
p_c = log(D.prior_probabilities) + log(D.probability_state_feedback); % log p(y_t,c_t)
end
else
if D.cuesExist
p_c = log(D.prior_probabilities) + log(D.probability_cue); % log p(q_t,c_t)
else
p_c = log(D.prior_probabilities); % log p(c_t)
end
end
l_w = obj.log_sum_exp(p_c); % log p(y_t,q_t)
p_c = p_c - l_w; % log p(c_t|y_t,q_t)
% weights for resampling
w = exp(l_w - obj.log_sum_exp(l_w'));
% draw indices of particles to propagate
if D.feedback_observed(D.t) || D.cuesExist
D.i_resampled = obj.systematic_resampling(w);
else
D.i_resampled = 1:obj.particles;
end
% store variables of the predictive distributions (optional)
% these variables are stored before resampling (so that they do not depend on the current state feedback)
variables_stored_before_resampling = {'predicted_probabilities' 'state_feedback_mean' 'state_feedback_var' 'state_mean' 'state_var' 'Kalman_gain_given_cstar2' 'state_given_cstar2'};
for i = 1:numel(obj.store)
variable = obj.store{i};
if any(strcmp(variable,variables_stored_before_resampling)) && isfield(D,variable)
D = obj.store_function(D,variable);
end
end
% resample variables (particles)
D.previous_context = D.context(D.i_resampled);
D.prior_probabilities = D.prior_probabilities(:,D.i_resampled);
D.predicted_probabilities = D.predicted_probabilities(:,D.i_resampled);
D.responsibilities = exp(p_c(:,D.i_resampled)); % p(c_t|y_t,q_t)
D.C = D.C(D.i_resampled);
D.state_mean = D.state_mean(:,D.i_resampled);
D.state_var = D.state_var(:,D.i_resampled);
D.prediction_error = D.prediction_error(:,D.i_resampled);
D.state_feedback_var = D.state_feedback_var(:,D.i_resampled);
D.probability_state_feedback = D.probability_state_feedback(:,D.i_resampled);
D.global_transition_probabilities = D.global_transition_probabilities(:,D.i_resampled);
D.n_context = D.n_context(:,:,D.i_resampled);
D.previous_state_filtered_mean = D.state_filtered_mean(:,D.i_resampled);
D.previous_state_filtered_var = D.state_filtered_var(:,D.i_resampled);
if D.cuesExist
D.global_cue_probabilities = D.global_cue_probabilities(:,D.i_resampled);
D.n_cue = D.n_cue(:,:,D.i_resampled);
end
D.retention = D.retention(:,D.i_resampled);
D.drift = D.drift(:,D.i_resampled);
D.dynamics_ss_1 = D.dynamics_ss_1(:,D.i_resampled,:);
D.dynamics_ss_2 = D.dynamics_ss_2(:,D.i_resampled,:,:);
if obj.infer_bias
D.bias = D.bias(:,D.i_resampled);
D.bias_ss_1 = D.bias_ss_1(:,D.i_resampled);
D.bias_ss_2 = D.bias_ss_2(:,D.i_resampled);
end
end
function D = sample_context(obj,D)
% sample the context
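% (categorical sampling via the inverse CDF: count how many cumulative responsibilities each uniform draw exceeds)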
D.context = sum(rand(1,obj.particles) > cumsum(D.responsibilities),1) + 1;
% increment the context count
D.p_new_x = find(D.context > D.C);
D.p_old_x = find(D.context <= D.C);
D.C(D.p_new_x) = D.C(D.p_new_x) + 1;
p_beta_x = D.p_new_x(D.C(D.p_new_x) ~= obj.max_contexts);
i = sub2ind([obj.max_contexts+1,obj.particles],D.context(p_beta_x),p_beta_x);
% sample the next stick-breaking weight
beta = betarnd(1,obj.gamma_context*ones(1,numel(p_beta_x)));
% update the global transition distribution
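% (stick-breaking: a fraction beta of the mass of the newly instantiated context
% is kept and the remaining 1-beta is passed to the next, as-yet-uninstantiated context)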
D.global_transition_probabilities(i+1) = D.global_transition_probabilities(i).*(1-beta);
D.global_transition_probabilities(i) = D.global_transition_probabilities(i).*beta;
if D.cuesExist
if obj.cues(D.t) > D.Q
% increment the cue count
D.Q = D.Q + 1;
% sample the next stick-breaking weight
beta = betarnd(1,obj.gamma_cue*ones(1,obj.particles));
% update the global cue distribution
D.global_cue_probabilities(D.Q+1,:) = D.global_cue_probabilities(D.Q,:).*(1-beta);
D.global_cue_probabilities(D.Q,:) = D.global_cue_probabilities(D.Q,:).*beta;
end
end
end
function D = update_belief_about_states(~,D)
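% standard per-context Kalman update: K = P_pred/(P_pred + R), where
% state_feedback_var = P_pred + R (R = observation noise variance); only the
% context sampled by each particle is updated (selected via the one-hot rows of D.H)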
D.Kalman_gains = D.state_var./D.state_feedback_var;
if D.feedback_observed(D.t)
D.state_filtered_mean = D.state_mean + D.Kalman_gains.*D.prediction_error.*D.H(D.context,:)';
D.state_filtered_var = (1 - D.Kalman_gains.*D.H(D.context,:)').*D.state_var;
else
D.state_filtered_mean = D.state_mean;
D.state_filtered_var = D.state_var;
end
end
function D = sample_states(obj,D)
n_new_x = numel(D.p_new_x);
i_old_x = sub2ind([obj.max_contexts+1,obj.particles],D.context(D.p_old_x),D.p_old_x);
i_new_x = sub2ind([obj.max_contexts+1,obj.particles],D.context(D.p_new_x),D.p_new_x);
% for states that have been observed before, sample x_{t-1}, and then sample x_{t} given x_{t-1}
% sample x_{t-1} using a fixed-lag (lag 1) forward-backward smoother
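% (Rauch-Tung-Striebel backward step: g = a*V_{t-1|t-1}/V_{t|t-1},
% m = m_{t-1|t-1} + g*(m_{t|t} - m_{t|t-1}), v = V_{t-1|t-1} + g^2*(V_{t|t} - V_{t|t-1}))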
g = D.retention.*D.previous_state_filtered_var./D.state_var;
m = D.previous_state_filtered_mean + g.*(D.state_filtered_mean - D.state_mean);
v = D.previous_state_filtered_var + g.*(D.state_filtered_var - D.state_var).*g;
D.previous_x_dynamics = m + sqrt(v).*randn(obj.max_contexts+1,obj.particles);
% sample x_t conditioned on x_{t-1} and y_t
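% (Gaussian posterior in information form: the precision 1/v is the sum of the
% prior and likelihood precisions, w is the precision-weighted sum of their
% means, and the posterior mean is v.*w)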
if D.feedback_observed(D.t)
w = (D.retention.*D.previous_x_dynamics + D.drift)./obj.sigma_process_noise^2 + D.H(D.context,:)'./D.sigma_observation_noise^2.*(D.state_feedback - D.bias);
v = 1./(1./obj.sigma_process_noise^2 + D.H(D.context,:)'./D.sigma_observation_noise^2);
else
w = (D.retention.*D.previous_x_dynamics + D.drift)./obj.sigma_process_noise^2;
v = 1./(1./obj.sigma_process_noise^2);
end
D.x_dynamics = v.*w + sqrt(v).*randn(obj.max_contexts+1,obj.particles);
% for novel states, sample x_t from the filtering distribution
x_samp_novel = D.state_filtered_mean(i_new_x) + sqrt(D.state_filtered_var(i_new_x)).*randn(1,n_new_x);
D.x_bias = [D.x_dynamics(i_old_x) x_samp_novel];
D.i_observed = [i_old_x i_new_x];
end
function D = update_sufficient_statistics_for_parameters(obj,D)
% update the sufficient statistics for the parameters of the
% global transition probabilities
D = obj.update_sufficient_statistics_global_transition_probabilities(D);
% update the sufficient statistics for the parameters of the
% global cue probabilities
if D.cuesExist
D = obj.update_sufficient_statistics_global_cue_probabilities(D);
end
if D.t > 1
% update the sufficient statistics for the parameters of the
% state dynamics function
D = obj.update_sufficient_statistics_dynamics(D);
end
% update the sufficient statistics for the parameters of the
% observation function
if obj.infer_bias && D.feedback_observed(D.t)
D = obj.update_sufficient_statistics_bias(D);
end
end
function D = sample_parameters(obj,D)
% sample the global transition probabilities
D = obj.sample_global_transition_probabilities(D);
% update the local context transition probability matrix
D = obj.update_local_transition_matrix(D);
if D.cuesExist
% sample the global cue probabilities
D = obj.sample_global_cue_probabilities(D);
% update the local cue probability matrix
D = obj.update_local_cue_matrix(D);
end
% sample the parameters of the state dynamics function
D = obj.sample_dynamics(D);
% sample the parameters of the observation function
if obj.infer_bias
D = obj.sample_bias(D);
else
D.bias = 0;
end
end
function D = store_variables(obj,D)
if any(strcmp(obj.store,'Kalman_gain_given_cstar1'))
[~,i] = max(D.responsibilities,[],1);
i = sub2ind(size(D.Kalman_gains),i,1:obj.particles);
D.Kalman_gain_given_cstar1 = mean(D.Kalman_gains(i));
end
if any(strcmp(obj.store,'predicted_probability_cstar1'))
[~,i] = max(D.responsibilities,[],1);
i = sub2ind(size(D.predicted_probabilities),i,1:obj.particles);
D.predicted_probability_cstar1 = mean(D.predicted_probabilities(i));
end
if any(strcmp(obj.store,'state_given_cstar1'))
[~,i] = max(D.responsibilities,[],1);
i = sub2ind(size(D.state_mean),i,1:obj.particles);
D.state_given_cstar1 = mean(D.state_mean(i));
end
% store variables of the filtering distributions (optional)
% these variables are stored after resampling (so that they depend on the current state feedback)
variables_stored_before_resampling = {'predicted_probabilities' 'state_feedback_mean' 'state_feedback_var' 'state_mean' 'state_var' 'Kalman_gain_given_cstar2' 'state_given_cstar2'};
for i = 1:numel(obj.store)
variable = obj.store{i};
if ~any(strcmp(variable,variables_stored_before_resampling))
D = obj.store_function(D,variable);
end
end
end
function D = update_sufficient_statistics_dynamics(obj,D)
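% the state dynamics x_t = a*x_{t-1} + d + noise define a Bayesian linear
% regression whose conjugate posterior depends on the data only through
% sum(x_t.*[x_{t-1};1]) and sum([x_{t-1};1]*[x_{t-1};1]')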
% augment the state vector: x_{t-1} --> [x_{t-1}; 1]
x_a = ones(obj.max_contexts+1,obj.particles,2);
x_a(:,:,1) = D.previous_x_dynamics;
% identify states that are not novel
I = reshape(sum(D.n_context,2),[obj.max_contexts+1,obj.particles]) > 0;
SS = D.x_dynamics.*x_a; % x_t*[x_{t-1}; 1]
D.dynamics_ss_1 = D.dynamics_ss_1 + SS.*I;
SS = reshape(x_a,[obj.max_contexts+1,obj.particles,2]).*reshape(x_a,[obj.max_contexts+1,obj.particles,1,2]); % [x_{t-1}; 1]*[x_{t-1}; 1]'
D.dynamics_ss_2 = D.dynamics_ss_2 + SS.*I;
end
function D = update_sufficient_statistics_bias(~,D)
D.bias_ss_1(D.i_observed) = D.bias_ss_1(D.i_observed) + (D.state_feedback - D.x_bias); % y_t - x_t
D.bias_ss_2(D.i_observed) = D.bias_ss_2(D.i_observed) + 1; % 1(c_t = j)
end
function D = update_sufficient_statistics_global_cue_probabilities(obj,D)
i = sub2ind([obj.max_contexts+1,max(obj.cues)+1,obj.particles],D.context,obj.cues(D.t)*ones(1,obj.particles),1:obj.particles); % 1(c_t = j, q_t = k)
D.n_cue(i) = D.n_cue(i) + 1;
end
function D = update_sufficient_statistics_global_transition_probabilities(obj,D)
i = sub2ind([obj.max_contexts+1,obj.max_contexts+1,obj.particles],D.previous_context,D.context,1:obj.particles); % 1(c_{t-1} = i, c_t = j)
D.n_context(i) = D.n_context(i) + 1;
end
function D = sample_dynamics(obj,D)