---
title: 'Model Fitting: Parameter Limits'
output:
  html_notebook: default
  html_document:
    df_print: paged
  pdf_document: default
author: Marius 't Hart
---
```{r setup, cache=FALSE, include=FALSE}
library(knitr)
opts_chunk$set(comment='', eval=FALSE)
```
For some models, parameters are restricted to a certain range. In the example below, we have retention rates: how much do you remember of something you learned earlier? This can't be negative and it can't be more than 100%. That's why it only makes sense to consider parameter values between 0 and 1. Here we look at how to do that.
We also look at benchmarking a little: comparing fitting algorithms in terms of speed. This may matter if you have to fit the same model to a really large number of individual datasets.
# Two-rate model of visuomotor adaptation
You are now ready to implement a more interesting model. Let's implement a simple version of the two-rate model of motor learning... or of visuomotor adaptation, as the case may be.
As you know, this model assumes that the reach deviation X on a trial t is the sum of the output of a fast process F and a slow process S also on trial t:
$X_t = F_t + S_t$
Each of these processes learns from the error E on the previous trial and retains some of its activity (F or S) from the previous trial:
$F_t = R_f \cdot F_{t-1} + L_f \cdot E_{t-1}$
and
$S_t = R_s \cdot S_{t-1} + L_s \cdot E_{t-1}$
As is customary, this model is fitted on data that includes an error-clamp phase, where errors are set to zero. In those phases the learning terms have no effect, as the errors are 0, so the retention factors fully determine the reach output.
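To get a feel for what that means, here is a tiny stand-alone sketch (the state values and retention rates are made up, purely for illustration): with the error clamped to zero, each process simply decays towards zero by its retention factor on every trial.
```{r}
# hypothetical process states (in degrees) and retention rates:
Ss <- 10; Rs <- 0.99   # slow process: state and retention
Ff <- 5;  Rf <- 0.75   # fast process: state and retention
# with the error clamped to 0, the next trial's output is pure retention:
c('slow' = Rs * Ss, 'fast' = Rf * Ff, 'total' = Rs * Ss + Rf * Ff)
```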
First, we need some data. Let's load it:
```{r}
load('data/tworatedata.rda')
```
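It can help to have a quick look at what we just loaded. Judging from how the data frame is used below, the first three columns hold trial information (including the perturbation `schedule`) and each remaining column holds one participant's reach deviations, but do check this on your own machine:
```{r}
str(tworatedata, list.len=6)  # structure of the data frame (only the first few columns)
head(tworatedata[ , 1:6])     # trial info plus the first few participants
```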
In order to fit this model, we need to have a single set of reaches; here we'll use the average across participants. We'll want to make the model flexible, so that it can fit any kind of paradigm. This means we also need to have some way to tell the model what the paradigm is.
Before we can get the reach deviations, we need to baseline the reaches for every participant, by subtracting that participant's mean reach deviation during the aligned phase from all of their reach deviations:
```{r}
baseline <- function(reachvector,blidx) reachvector - mean(reachvector[blidx], na.rm=TRUE)
tworatedata[,4:ncol(tworatedata)] <- apply(tworatedata[,4:ncol(tworatedata)], FUN=baseline, MARGIN=c(2), blidx=c(17:32))
```
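If you want a quick sanity check: after baselining, every participant's mean reach deviation over the aligned trials (17 to 32, the `blidx` we just used) should be essentially zero.
```{r}
# per-participant means over the aligned phase; these should all be (close to) 0 now:
round(apply(tworatedata[17:32, 4:ncol(tworatedata)], MARGIN=c(2), FUN=mean, na.rm=TRUE), 6)
```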
Now we'll get the mean reach deviations across participants, and plot them.
```{r}
reaches <- apply(tworatedata[4:ncol(tworatedata)], FUN=mean, MARGIN=c(1), na.rm=TRUE)
plot(reaches,type='p',xlab='trial',ylab='reach deviation [deg]',xlim=c(0,165),ylim=c(-35,35),bty='n',ax=F)
lines(c(1,33,33,133,133,145,145),c(0,0,30,30,-30,-30,0),col='#AAAAAA')
lines(c(145,164),c(0,0),col='#AAAAAA',lty=2)
axis(1,c(1,32,132,144,164),las=2)
axis(2,c(-30,-15,0,15,30))
```
OK, so that looks like pretty decent data, with a spontaneous recovery (or rebound) between trials 144 and 164, which we can hopefully explain with the two-rate model. For this model we will write two separate functions to get the mean squared error. The first needs a schedule and a set of parameter values; it runs the model and calculates the model's reach deviations. The second needs a schedule, a set of parameter values and some actual reach data. It checks the sanity of the parameters, then uses the first function and returns the MSE, or some "large" error value if some of the model's constraints are not met.
## Model function
The function that runs the model will use a for loop. I'm not sure exactly why, but smart people tell me this can't be solved in one line.
```{r}
twoRateModel <- function(par, schedule) {
  # these values should be zero at the start of the loop:
  Et <- 0 # previous error: none
  St <- 0 # state of the slow process: aligned
  Ft <- 0 # state of the fast process: aligned
  # we'll store what happens on each trial in these vectors:
  slow <- c()
  fast <- c()
  total <- c()
  # now we loop through the perturbations in the schedule:
  for (t in c(1:length(schedule))) {
    # first we calculate what the model does
    # this happens before we get visual feedback about potential errors
    St <- (par['Rs'] * St) - (par['Ls'] * Et)
    Ft <- (par['Rf'] * Ft) - (par['Lf'] * Et)
    Xt <- St + Ft
    # now we calculate what the previous error will be for the next trial:
    if (is.na(schedule[t])) {
      Et <- 0
    } else {
      Et <- Xt + schedule[t]
    }
    # at this point we save the states in our vectors:
    slow <- c(slow, St)
    fast <- c(fast, Ft)
    total <- c(total, Xt)
  }
  # after we loop through all trials, we return the model output:
  return(data.frame(slow,fast,total))
}
```
We can play around with this model by giving it some more or less common parameter values:
```{r}
par <- c('Ls'=.05, 'Lf'=.15, 'Rs'=.99, 'Rf'=.75)
schedule <- tworatedata$schedule
model <- twoRateModel(par=par, schedule=schedule)
```
That seems to work reasonably well. Let's add this to our plot of the data:
```{r}
plot(reaches,type='l',col='#333333',xlab='trial',ylab='reach deviation [deg]',xlim=c(0,165),ylim=c(-35,35),bty='n',ax=F)
lines(c(1,33,33,133,133,145,145),c(0,0,30,30,-30,-30,0),col='#AAAAAA')
lines(c(145,164),c(0,0),col='#AAAAAA',lty=2)
lines(model$slow,col='blue')
lines(model$fast,col='red')
lines(model$total,col='purple')
axis(1,c(1,32,132,144,164),las=2)
axis(2,c(-30,-15,0,15,30))
```
These parameters undershoot the learning in the initial phase, as well as the rebound. How do we find better parameter values? We could play around with the settings until we find some values that work well, but that could take a long time and there is no guarantee that those values would be even close to the best parameters. Instead we want to have the computer do the heavy work, by searching the parameter space in a systematic way.
Now that we have a function that runs the model, we can use it in a function that calculates how wrong a set of model parameters is. We can calculate the mean of the squared differences between the total model output and the actual reach deviations that we measured. So we will definitely do that.
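In formula form, with $X_t$ the model's total output and $r_t$ the measured reach deviation on trial $t$ (and $N$ the number of trials with data), that error measure is:
$MSE = \frac{1}{N} \sum_{t=1}^{N} (X_t - r_t)^2$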
## Error function [1/2]
This model can be wrong in other ways too. First, all parameters have to be between 0 and 1 (perhaps they can also be exactly 0 or 1); second, the slow learning rate should be lower than the fast learning rate; and third, the slow retention rate should be larger than the fast retention rate.
```{r}
twoRateMSE <- function(par, schedule, reaches) {
  # parameter values should be between 0 and 1:
  if (any(par > 1)) {
    return(Inf)
  }
  if (any(par < 0)) {
    return(Inf)
  }
  # learning and retention rates of the fast and slow process are constrained:
  if (par['Ls'] > par['Lf']) {
    return(Inf)
  }
  if (par['Rs'] < par['Rf']) {
    return(Inf)
  }
  return( mean((twoRateModel(par, schedule)$total - reaches)^2, na.rm=TRUE) )
}
```
For the parameters that we just came up with, the MSE would be:
```{r}
print(twoRateMSE(par, schedule, reaches))
```
That is on average about 5 degrees off on every trial (the square root of the MSE), including the aligned phase. This can be improved!
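You can check that conversion in one line:
```{r}
# root mean squared error, in degrees:
sqrt(twoRateMSE(par, schedule, reaches))
```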
## First optimization
The strategy from before was to use `optimx()`, so let's see if that works:
```{r}
library(optimx)
optimx(par = par,
       fn = twoRateMSE,
       schedule = schedule,
       reaches = reaches)
```
First of all, BFGS doesn't return anything, and while Nelder-Mead reduces the MSE to 5.2 (so ~2.3 degrees of error on every trial), there is a warning about eigenvalue failures and it still takes 277 function evaluations.
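As an aside: `optimx()` returns a data frame with one row per method, and the parameter columns keep the names from `par`, so if you store the result you can pull the fitted values out directly. A minimal sketch, with the result stored in a variable here called `fit`:
```{r}
# repeat the fit from above, but store the result this time:
fit <- optimx(par = par, fn = twoRateMSE, schedule = schedule, reaches = reaches)
# fitted parameter values from the Nelder-Mead row (rows are named after the methods):
unlist(fit['Nelder-Mead', c('Ls','Lf','Rs','Rf')])
```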
## Error function [2/2]
There are some problems with the above function. First, returning `Inf` as a very large error might be too much. Let's change this to roughly 10 times the MSE of a model that learns nothing at all (the code below uses the mean squared perturbation for this). Second, we built the boundaries into the error function, but we can also set them explicitly, at least for some of the fitting methods. Let's try to fix both. First, we change the error function:
```{r}
twoRateMSE <- function(par, schedule, reaches) {
  bigError <- mean(schedule^2, na.rm=TRUE) * 10
  # learning and retention rates of the fast and slow process are constrained:
  if (par['Ls'] > par['Lf']) {
    return(bigError)
  }
  if (par['Rs'] < par['Rf']) {
    return(bigError)
  }
  return( mean((twoRateModel(par, schedule)$total - reaches)^2, na.rm=TRUE) )
}
```
This is what `optimx()` does with this:
```{r}
library(optimx)
optimx(par = par,
       fn = twoRateMSE,
       schedule = schedule,
       reaches = reaches)
```
Now both methods work and seem to provide similar model fits. However, we still want to set the upper and lower bounds in our call to `optimx()`:
```{r}
library(optimx)
optimx(par = par,
       fn = twoRateMSE,
       lower = c(0,0,0,0),
       upper = c(1,1,1,1),
       schedule = schedule,
       reaches = reaches)
```
That works, but doesn't seem to change all that much. It does omit the Nelder-Mead fit: Nelder-Mead can't handle limits on parameters, so when bounds are given, `optimx()` only uses methods that can deal with them.
# Benchmarking
For this data it doesn't matter so much, as the model is fitted very quickly, but once you start getting into more complicated problems, you might want to think about optimizing your code a little bit. One way to do this for model fitting is to see which method runs faster. In the tables from `optimx()` we already got a clue: the column `fevals` tells us how often the error function was evaluated by each of the two methods. Nelder-Mead needed to call the error function 55 times, whereas BFGS only used it 10 times. However, there might be other factors at play that affect the total time needed for a single model fit.
There are ways to measure the total time it takes to run any command. If you're not interested, this section can be skipped.
We can use the `microbenchmark` package to do this. Specifically, it allows several commands to be tested at the same time, and provides some very nice output, even pretty plots if you want them. We need to specify the commands to test and how many times to test them (for more accurate results):
```{r}
library(microbenchmark)
res <- microbenchmark('Nelder-Mead' = optimx(par = par,
                                             fn = twoRateMSE,
                                             method='Nelder-Mead',
                                             schedule = schedule, reaches = reaches),
                      'BFGS' = optimx(par = par,
                                      fn = twoRateMSE,
                                      method='BFGS',
                                      schedule = schedule, reaches = reaches),
                      'L-BFGS-B' = optimx(par = par,
                                          fn = twoRateMSE,
                                          lower = c(0,0,0,0), upper=c(1,1,1,1), # boundaries!
                                          method='L-BFGS-B',
                                          schedule = schedule, reaches = reaches),
                      times = 20
                      )
print(res)
```
In this table you mostly want to look at the mean and median columns, as they would be most predictive of differences in runtime between the different fitting methods. Unfortunately, the results vary a lot per run in this case, perhaps because `times=20` is too low for reliable results? Nelder-Mead is usually slower than the others on my machine (like the author of `optim` warned us), and BFGS might be somewhat faster in this case. However, they're all comparable for our purposes (they finish within a second), and by using L-BFGS-B (limited-memory BFGS with box constraints) we don't have to check the parameter limits in our own code.
If you don't have `ggplot2` installed, you can use `boxplot()` for a visualization:
```{r}
boxplot(res)
```
Or if you do have `ggplot2` you can get a prettier plot:
```{r}
if (require("ggplot2")) {
  autoplot(res)
}
```
## Two-rate grid search
At this point, we can run our optimization algorithm of choice on our model so that it fits the data, but there is one possible problem that we still want to take care of: local minima. We'll use the solution we know from the previous tutorial: grid search!
```{r}
nvals <- 5
parvals <- seq(1/nvals/2,1-(1/nvals/2),1/nvals)
searchgrid <- expand.grid('Ls'=parvals,
                          'Lf'=parvals,
                          'Rs'=parvals,
                          'Rf'=parvals)
# evaluate starting positions:
MSE <- apply(searchgrid, FUN=twoRateMSE, MARGIN=c(1), schedule=schedule, reaches=reaches)
# run optimx on the best starting positions:
allfits <- do.call("rbind",
                   apply( searchgrid[order(MSE)[1:10],],
                          MARGIN=c(1),
                          FUN=optimx,
                          fn=twoRateMSE,
                          method='L-BFGS-B',
                          lower=c(0,0,0,0),
                          upper=c(1,1,1,1),
                          schedule=schedule,
                          reaches=reaches ) )
# pick the best fit:
win <- allfits[order(allfits$value)[1],]
print(win[1:5])
```
The MSE `value` is very similar to the one we had before without grid search, and so are the parameter values. Perhaps this means we didn't need grid search after all, but with this model there is a 4-dimensional parameter space and no obvious way to see if and where local minima occur, so it is safer to do grid search.
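One way to reassure yourself is to look at all ten fits rather than just the winner: if fits that started from different grid points end up with (nearly) the same parameters and MSE, local minima are probably not a big problem for this dataset. A quick sketch using the `allfits` data frame from the chunk above:
```{r}
# all ten fits, sorted from best to worst MSE:
print( allfits[order(allfits$value), c('Ls','Lf','Rs','Rf','value')] )
```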
## Plot model
Let's plot the winning model on top of the data:
```{r}
model <- twoRateModel(par=unlist(win[,c(1:4)]), schedule=schedule)
plot(reaches,type='l',col='#333333',xlab='trial',ylab='reach deviation [deg]',xlim=c(0,165),ylim=c(-35,35),bty='n',ax=F)
lines(c(1,33,33,133,133,145,145),c(0,0,30,30,-30,-30,0),col='#AAAAAA')
lines(c(145,164),c(0,0),col='#AAAAAA',lty=2)
lines(model$slow,col='blue')
lines(model$fast,col='red')
lines(model$total,col='purple')
axis(1,c(1,32,132,144,164),las=2)
axis(2,c(-30,-15,0,15,30))
```
That _does_ look a lot better than with the parameters we made up.
# Model quality
Usually, what we will want to do is test how well one given model explains data from two different conditions, or test how well two different models explain the same data. One of the assumptions there is that both model fits really are optimal model fits, so taking all the extra steps described above makes sense. Evaluating the quality of model fits is a statistics question, so the goal is to add a separate tutorial about this issue.