forked from jamesheald/COIN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
DMexp_cue.m
108 lines (85 loc) · 2.68 KB
/
DMexp_cue.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
% Simulate the COIN model on each participant's perturbation/cue sequence and
% store, for every trial, the motor error: perturbation minus the
% OUTPUT.weights-weighted combination of the per-run motor outputs.
P = 1;          % number of participants to simulate (participants 1..P)
trials = 768;   % expected number of trials per participant
output_motor_with_noise = true;  % if true, add Gaussian motor noise to each run's output
pert_seq_codes = zeros(1,P);
% NOTE(review): this variable shadows MATLAB's built-in error() for the rest of
% the workspace. It keeps this name so later code reading `error` still works.
% Use assert() rather than error() to raise errors below.
error = zeros(P,trials);
% Per-model plotting and figure saving is only enabled for a single-participant run.
plot_and_save_all = (P == 1);
data_dir = fullfile('data', 'DMdata');
for participant = 1:P
    file = dir(fullfile(data_dir, sprintf('sub%02d_*.csv', participant)));
    assert(~isempty(file), 'No data file matching sub%02d_*.csv found in %s', participant, data_dir);
    % dir() returns only the bare file name, so prepend the folder it was found in.
    filename = fullfile(data_dir, file(1).name);
    data = readtable(filename);
    % cues = [data.cue_with_perturb].';
    cues = [data.target_inds].';
    % NOTE(review): 0.52 presumably normalises the perturbation to unit
    % magnitude (e.g. 0.52 rad is about 30 deg). TODO: confirm.
    perturbations = [data.perturbation].' / 0.52;
    assert(numel(perturbations) == trials, ...
        'Participant %d has %d trials; expected %d', participant, numel(perturbations), trials);
    % Expects exactly one sequence code per participant file.
    pert_seq_codes(participant) = unique(data.pert_seq_code);
    % create an object of the COIN class
    obj = COIN;
    obj.perturbations = perturbations;
    obj.cues = cues;
    obj.renumber_cues;
    obj.runs = 5;
    obj.max_cores = feature('numcores');
    if plot_and_save_all
        obj.infer_bias = true;
        obj.plot_state_given_context = true;
        obj.plot_predicted_probabilities = true;
        obj.plot_state = true;
        obj.plot_bias_given_context = true;
        obj.plot_bias = true;
        obj.plot_state_feedback = true;
        obj.plot_explicit_component = true;
        obj.plot_implicit_component = true;
    end
    OUTPUT = obj.simulate_COIN;
    % Reset the per-run buffers for every participant so no stale rows carry over.
    motor_output = zeros(obj.runs, trials);
    % Length of state_feedback is not visible here, so let it size itself.
    state_feedback_output = [];
    for run = 1:obj.runs
        noiseless_motor_output = OUTPUT.runs{run}.motor_output;
        if output_motor_with_noise
            % size() keeps the noise the same shape as the model output.
            motor_noise = randn(size(noiseless_motor_output))*obj.sigma_motor_noise;
        else
            motor_noise = 0;
        end
        motor_output(run,:) = noiseless_motor_output + motor_noise;
        % Collected per run; not used elsewhere in this visible script.
        state_feedback_output(run,:) = OUTPUT.runs{run}.state_feedback;
    end
    % Weighted combination of the runs' motor outputs, subtracted from the perturbation.
    error(participant,:) = obj.perturbations - OUTPUT.weights*motor_output;
end
% Map each participant's sequence code c to a sign 1 - 2*c (0 -> +1, 1 -> -1;
% assumes codes are 0 or 1, TODO confirm). Multiply each row of the error
% matrix by its sign so all participants are expressed with a common sign.
pert_seq_codes = 1 - 2 * pert_seq_codes;
error_matrix = diag(pert_seq_codes) * error;
% In trial 72 of subject 14 the perturbation is wrong, so subject 14's row is
% removed before averaging whenever that subject was simulated (P >= 14).
if P >= 14
    error_matrix(14, :) = [];
end
% Average across participants (dimension 1) to get one value per trial. The
% explicit dimension keeps error_ave 1-by-trials even when only one row is
% present. Without it, mean() would average along the trials instead.
error_ave = mean(error_matrix, 1);
% figure('Position', [50, 50, 900, 400]);
% plot(error_ave)
% grid on;
% xlim([0,800]);
% ylim([-1.2,1.2]);
% xlabel('Trials')
% ylabel('Motor Error')
% % title("COIN error average with cues")
% legend('off');
% print('sim_perturbcue_real_pert.png', '-dpng'); % Saves the figure as a PNG file
% Save all figures for a single simulation
if plot_and_save_all
    % Plot the per-trial motor error before the pert_seq_code sign flip.
    figure
    plot(error)
    grid on;
    xlim([0,800]);
    ylim([-1.2,1.2]);
    xlabel('Trials')
    ylabel('Motor Error')
    % title("COIN error average without cues")
    legend('off');
    % Write every open figure to ./figureoutput_COIN/figure<N>.png, where N is
    % the figure number. Create the folder first so print() does not fail.
    output_dir = fullfile('.', 'figureoutput_COIN');
    if ~exist(output_dir, 'dir')
        mkdir(output_dir);
    end
    % Only existing figures are saved, so no blank figures are created.
    figs = findall(groot, 'Type', 'figure');
    for k = 1:numel(figs)
        fig = figs(k);
        if isempty(fig.Number)
            continue;  % figure without an integer number: nothing to name the file by
        end
        % Set figure size [left, bottom, width, height]
        set(fig, 'Position', [10, 10, 400, 300]);
        print(fig, fullfile(output_dir, sprintf('figure%d.png', fig.Number)), '-dpng');
    end
    % Close all figure windows
    % close all;
end