---
title: "Interaction detection using Random Forest predictions"
author: Gertjan Verhoeven
date: September 2017
output:
  pdf_document:
    toc: true
    toc_depth: 3
---
# Summary
The idea is that comparing the predictions of a Random Forest (RF) model with those of an OLS model can tell us in which ways the OLS model fails to capture the non-linearities and interactions between the predictors. Partial dependence plots of the RF model can then guide the modelling of the non-linearities in the OLS model. After this step, the remaining discrepancies between the RF and OLS predictions should be caused by non-modelled interactions. Using an RF to predict the discrepancy itself can then reveal which predictors are involved in these interactions. We test this method on the classic `Boston Housing` dataset, predicting median house values (`medv`). We indeed recover interactions that, as it turns out, have already been documented in the literature.
# Load packages
```{r, results='hide', message=FALSE, warning=FALSE}
rm(list = ls())
library(randomForest)
library(party)
library(ranger)
library(data.table)
library(ggplot2)
library(MASS)

# Helper: draw n values from a discrete uniform distribution on 1..k (not used below)
rdunif <- function(n, k) sample(1:k, n, replace = TRUE)
```
# Step 1: Run an RF on the Boston Housing set
```{r}
my_ranger <- ranger(medv ~ ., data = Boston,
                    importance = "permutation", num.trees = 500,
                    mtry = 5, replace = TRUE)
```
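Printing the fitted object reports the out-of-bag (OOB) prediction error and R-squared, a useful baseline for the comparisons below:
```{r}
my_ranger
```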
Extract the permutation importance measure.
```{r}
myres_tmp <- ranger::importance(my_ranger)
# Tidy the named importance vector into a data.table for plotting
myres <- data.table(varname = factor(names(myres_tmp)),
                    MeanDecreaseAccuracy = as.numeric(myres_tmp),
                    i = 1L)
```
```{r}
ggplot(myres,
aes(x = reorder(factor(varname), MeanDecreaseAccuracy), y = MeanDecreaseAccuracy)) +
geom_point() + coord_flip()
```
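The summary mentioned partial dependence plots as the tool for spotting non-linearities. Below is a minimal manual sketch for `lstat`; the grid size and averaging loop are our own choices (packages such as `pdp` provide the same functionality).
```{r}
# Manual partial dependence of medv on lstat: sweep lstat over a grid,
# holding the other predictors at their observed values, and average
# the RF predictions.
grid_lstat <- seq(min(Boston$lstat), max(Boston$lstat), length.out = 25)
pd_lstat <- sapply(grid_lstat, function(x) {
  tmp <- Boston
  tmp$lstat <- x
  mean(predict(my_ranger, data = tmp)$predictions)
})
plot(grid_lstat, pd_lstat, type = "b",
     xlab = "lstat", ylab = "partial dependence of medv")
```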
# Fit an OLS model to the Boston Housing data
```{r}
my_glm <- glm(medv ~ ., data = Boston,
              family = "gaussian")
```
# Compare predictions of both models
```{r}
pred_RF <- predict(my_ranger, data = Boston)
# Note: predict.glm takes `newdata`, not `data`
pred_GLM <- predict(my_glm, newdata = Boston)
plot(pred_RF$predictions, pred_GLM)
abline(0, 1)
```
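Before modelling the discrepancy it is useful to quantify it; a minimal check on the objects just computed:
```{r}
# How strongly do the two sets of predictions agree,
# and how large are the deviations?
cor(pred_RF$predictions, pred_GLM)
summary(pred_RF$predictions - pred_GLM)
```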
# Run an RF on the discrepancy
The discrepancy is defined as the difference between the two models' predictions for each observation.
```{r}
pred_diff <- pred_RF$predictions - pred_GLM
my_ranger_diff <- ranger(Ydiff ~ . - medv,
                         data = data.table(Ydiff = pred_diff, Boston),
                         importance = "permutation", num.trees = 500,
                         mtry = 5, replace = TRUE)
my_ranger_diff
```
It turns out the RF can "explain" about 67% of the variance in these discrepancies.
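The out-of-bag R-squared is also available programmatically on the ranger fit:
```{r}
# OOB R^2 of the discrepancy model
my_ranger_diff$r.squared
```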
```{r}
myres_tmp <- ranger::importance(my_ranger_diff)
# Tidy the named importance vector into a data.table for plotting
myres <- data.table(varname = factor(names(myres_tmp)),
                    MeanDecreaseAccuracy = as.numeric(myres_tmp),
                    i = 1L)
```
```{r}
ggplot(myres,
aes(x = reorder(factor(varname), MeanDecreaseAccuracy), y = MeanDecreaseAccuracy)) +
geom_point() + coord_flip()
```
It turns out that `rm` and `lstat` are the variables that best predict the discrepancy.
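This can be made explicit by sorting the importance table built above:
```{r}
# Top two predictors of the discrepancy, by permutation importance
myres[order(-MeanDecreaseAccuracy)][1:2]
```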
```{r}
my_glm_int <- glm(medv ~ . + rm:lstat, data = Boston,
                  family = "gaussian")
summary(my_glm_int)
```
The added interaction is indeed highly significant.
Compare the approximate out-of-sample prediction accuracy of both models using AIC:
```{r}
AIC(my_glm)
AIC(my_glm_int)
```
Indeed, adding the interaction substantially improves the model fit (lower AIC).
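Since AIC only approximates out-of-sample accuracy, a direct check is easy to sketch; the fold assignment and RMSE helper below are our own additions, not part of the original analysis.
```{r}
# Compare out-of-sample RMSE of both specifications via 10-fold CV
set.seed(123)
folds <- sample(rep(1:10, length.out = nrow(Boston)))
cv_rmse <- function(form) {
  mse <- sapply(1:10, function(k) {
    fit <- glm(form, data = Boston[folds != k, ], family = "gaussian")
    mean((Boston$medv[folds == k] -
            predict(fit, newdata = Boston[folds == k, ]))^2)
  })
  sqrt(mean(mse))
}
cv_rmse(medv ~ .)             # baseline OLS
cv_rmse(medv ~ . + rm:lstat)  # with the interaction
```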
# Repeat this process
```{r}
pred_RF <- predict(my_ranger, data = Boston)
pred_GLM <- predict(my_glm_int, newdata = Boston)
plot(pred_RF$predictions, pred_GLM)
abline(0, 1)
```
```{r}
pred_diff <- pred_RF$predictions - pred_GLM
my_ranger_diff2 <- ranger(Ydiff ~ . - medv,
                          data = data.table(Ydiff = pred_diff, Boston),
                          importance = "permutation", num.trees = 500,
                          mtry = 5, replace = TRUE)
my_ranger_diff2
```
```{r}
myres_tmp <- ranger::importance(my_ranger_diff2)
# Tidy the named importance vector into a data.table for plotting
myres <- data.table(varname = factor(names(myres_tmp)),
                    MeanDecreaseAccuracy = as.numeric(myres_tmp),
                    i = 1L)
```
```{r}
ggplot(myres,
aes(x = reorder(factor(varname), MeanDecreaseAccuracy), y = MeanDecreaseAccuracy)) +
geom_point() + coord_flip()
```
Now the variables that best predict the remaining discrepancy are `lstat` and `dis`.
Add their interaction to the model.
```{r}
my_glm_int2 <- glm(medv ~ . + rm:lstat + lstat:dis, data = Boston,
                   family = "gaussian")
summary(my_glm_int2)
AIC(my_glm_int2)
AIC(my_glm_int)
```
We conclude that the second interaction also results in a significant model improvement.
# A more ambitious goal: try to improve Harrison & Rubinfeld's model formula for Boston Housing
So far, we have assumed that all relationships are linear.
Harrison and Rubinfeld (1978) created a model without interactions, but with transformations to correct for skewness, heteroskedasticity, etc.
Let's see if we can improve upon this model equation by applying our method to search for interactions.
Their formula predicts `log(medv)`.
```{r}
# Harrison and Rubinfeld (1978) model
my_glm_hr <- glm(log(medv) ~ I(rm^2) + age + log(dis) + log(rad) + tax +
                   ptratio + black + I(black^2) + log(lstat) + crim + zn +
                   indus + chas + I(nox^2),
                 data = Boston, family = "gaussian")
summary(my_glm_hr)

my_ranger_log <- ranger(log(medv) ~ ., data = Boston,
                        importance = "permutation", num.trees = 500,
                        mtry = 5, replace = TRUE)
```
```{r}
pred_RF <- predict(my_ranger_log, data = Boston)
pred_GLM <- predict(my_glm_hr, newdata = Boston)
plot(pred_RF$predictions, pred_GLM)
abline(0, 1)
```
For low predicted values, the two models differ systematically.
This suggests a remaining pattern that the RF picks up but the OLS model does not.
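A direct way to see this is to plot the discrepancy against the OLS prediction; a systematic trend at the low end is what we expect if a pattern is missed there. A minimal sketch using the predictions computed above:
```{r}
# Plot the RF - OLS discrepancy against the OLS prediction (log scale)
plot(pred_GLM, pred_RF$predictions - pred_GLM,
     xlab = "OLS prediction of log(medv)", ylab = "RF - OLS discrepancy")
abline(h = 0, lty = 2)
```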
```{r}
pred_diff <- pred_RF$predictions - pred_GLM
my_ranger_log_diff <- ranger(Ydiff ~ . - medv,
                             data = data.table(Ydiff = pred_diff, Boston),
                             importance = "permutation", num.trees = 500,
                             mtry = 5, replace = TRUE)
my_ranger_log_diff
```
The RF can "explain" about 54% of the variance in this discrepancy.
```{r}
myres_tmp <- ranger::importance(my_ranger_log_diff)
# Tidy the named importance vector into a data.table for plotting
myres <- data.table(varname = factor(names(myres_tmp)),
                    MeanDecreaseAccuracy = as.numeric(myres_tmp),
                    i = 1L)
```
```{r}
ggplot(myres,
aes(x = reorder(factor(varname), MeanDecreaseAccuracy), y = MeanDecreaseAccuracy)) +
geom_point() + coord_flip()
```
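The importance plot points to `lstat` and `nox`. Before adding their interaction, a joint partial dependence sketch can make it visible: if the `lstat` profiles differ across `nox` levels (non-parallel curves), the RF has found an interaction. The grid choices below are our own.
```{r}
# Joint partial dependence of log(medv) on lstat and nox
grid <- expand.grid(lstat = seq(min(Boston$lstat), max(Boston$lstat),
                                length.out = 20),
                    nox   = quantile(Boston$nox, probs = c(0.1, 0.5, 0.9)))
grid$pd <- apply(grid, 1, function(g) {
  tmp <- Boston
  tmp$lstat <- g["lstat"]
  tmp$nox   <- g["nox"]
  mean(predict(my_ranger_log, data = tmp)$predictions)
})
ggplot(grid, aes(x = lstat, y = pd, colour = factor(round(nox, 3)))) +
  geom_line() +
  labs(y = "partial dependence of log(medv)", colour = "nox")
```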
Add the top two variables as an interaction to their model equation.
```{r}
my_glm_hr_int <- glm(log(medv) ~ I(rm^2) + age + log(dis) + log(rad) + tax +
                       ptratio + black + I(black^2) + log(lstat) + crim + zn +
                       indus + chas + I(nox^2) + lstat:nox,
                     data = Boston, family = "gaussian")
summary(my_glm_hr_int)
AIC(my_glm_hr)
AIC(my_glm_hr_int)
```
This results in a significant improvement!
# Repeat this procedure
```{r}
pred_RF <- predict(my_ranger_log, data = Boston)
pred_GLM <- predict(my_glm_hr_int, newdata = Boston)
plot(pred_RF$predictions, pred_GLM)
abline(0, 1)
```
```{r}
pred_diff <- pred_RF$predictions - pred_GLM
my_ranger_log_diff2 <- ranger(Ydiff ~ . - medv,
                              data = data.table(Ydiff = pred_diff, Boston),
                              importance = "permutation", num.trees = 500,
                              mtry = 5, replace = TRUE)
my_ranger_log_diff2
```
```{r}
myres_tmp <- ranger::importance(my_ranger_log_diff2)
# Tidy the named importance vector into a data.table for plotting
myres <- data.table(varname = factor(names(myres_tmp)),
                    MeanDecreaseAccuracy = as.numeric(myres_tmp),
                    i = 1L)
```
```{r}
ggplot(myres,
aes(x = reorder(factor(varname), MeanDecreaseAccuracy), y = MeanDecreaseAccuracy)) +
geom_point() + coord_flip()
```
Now we add `lstat` and `dis` as an interaction.
```{r}
my_glm_hr_int2 <- glm(log(medv) ~ I(rm^2) + age + log(dis) + log(rad) + tax +
                        ptratio + black + I(black^2) + log(lstat) + crim + zn +
                        indus + chas + I(nox^2) + lstat:nox + lstat:dis,
                      data = Boston, family = "gaussian")
summary(my_glm_hr_int2)
AIC(my_glm_hr_int2)
AIC(my_glm_hr_int)
```
And again we find an improvement in model fit.
# Have these interactions already been reported in the literature?
Tom Minka reports an analysis of interactions in the Boston Housing set on his website (http://alumni.media.mit.edu/~tpminka/courses/36-350.2001/lectures/day30/):
```
> summary(fit3)
Coefficients:
                  Estimate Std. Error t value Pr(>|t|)
(Intercept)      -227.5485    49.2363  -4.622 4.87e-06 ***
lstat              50.8553    20.3184   2.503 0.012639 *
rm                 38.1245     7.0987   5.371 1.21e-07 ***
dis               -16.8163     2.9174  -5.764 1.45e-08 ***
ptratio            14.9592     2.5847   5.788 1.27e-08 ***
lstat:rm           -6.8143     3.1209  -2.183 0.029475 *
lstat:dis           4.8736     1.3940   3.496 0.000514 ***
lstat:ptratio      -3.3209     1.0345  -3.210 0.001412 **
rm:dis              2.0295     0.4435   4.576 5.99e-06 ***
rm:ptratio         -1.9911     0.3757  -5.299 1.76e-07 ***
lstat:rm:dis       -0.5216     0.2242  -2.327 0.020364 *
lstat:rm:ptratio    0.3368     0.1588   2.121 0.034423 *
```
Rob McCulloch, using BART (Bayesian Additive Regression Trees), also examines interactions in the Boston Housing data. There, the co-occurrence of variables within trees is used to discover interactions:

> The second, interaction detection, uncovers which pairs of variables interact in analogous fashion by keeping track of the percentage of trees in the sum in which both variables occur. This exploits the fact that a sum-of-trees model captures an interaction between xi and xj by using them both for splitting rules in the same tree.
http://www.rob-mcculloch.org/some_papers_and_talks/papers/working/cgm_as.pdf
![](boston_uit_bart_book.png)
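A rough analogue of this co-occurrence idea can be computed from a `ranger` fit with `treeInfo()`; the helper below is our own sketch. Note that ranger grows deep trees by default, so most variables occur in most trees and the measure is far less discriminating than with BART's shallow trees.
```{r}
# Fraction of trees in which both variables appear as split variables
cooccur <- function(rf, v1, v2) {
  mean(sapply(seq_len(rf$num.trees), function(k) {
    vars <- ranger::treeInfo(rf, tree = k)$splitvarName
    (v1 %in% vars) && (v2 %in% vars)
  }))
}
cooccur(my_ranger_log, "lstat", "nox")
```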
# Conclusion
We conclude that this appears to be a fruitful approach, at least for discovering where a regression model can be improved.