
Commit 0665371

Updates
1 parent 4742c40 commit 0665371

11 files changed: +147 −499 lines

DESCRIPTION (+1 −1)
@@ -1,6 +1,6 @@
 Package: apm
 Title: Averaged Prediction Models
-Version: 0.0.0.9002
+Version: 0.1.0
 Authors@R:
     c(person("Thomas", "Leavitt", email = "[email protected]", role = c("aut"),
       comment = c(ORCID = "0000-0002-3668-6409")),

R/apm_est.R (+7 −23)
@@ -113,19 +113,15 @@ apm_est <- function(fits, post_time, M = 0, R = 1000L, all_models = FALSE, cl =
   BMA_weights <- fits$BMA_weights
 
   #Remove models that won't contribute
-  if (!all_models) {
+  if (all_models) {
+    models_to_keep <- seq_along(models)
+    fits_to_keep <- seq_along(fits$val_fits)
+  }
+  else {
     models_to_keep <- which(BMA_weights > 0)
     fits_to_keep <- which(grid$model %in% models_to_keep)
 
-    # models <- models[models_to_keep]
     BMA_weights <- BMA_weights[models_to_keep]
-
-    # grid <- grid[fits_to_keep, , drop = FALSE]
-    # fits$val_fits <- fits$val_fits[fits_to_keep]
-  }
-  else {
-    models_to_keep <- seq_along(models)
-    fits_to_keep <- seq_along(fits$val_fits)
   }
 
   #Prep everything for bootstrap that doesn't involve weights
@@ -163,10 +159,8 @@ apm_est <- function(fits, post_time, M = 0, R = 1000L, all_models = FALSE, cl =
     ti <- grid[["time_ind"]][fi]
 
     d <- mods[[mi]]$data
-
-    time <- val_times[ti]
-
-    .subset_f_post_list[[fi]] <- which(d[[time_var]] == time)
+
+    .subset_f_post_list[[fi]] <- which(d[[time_var]] == val_times[ti])
 
     .val_data_f_val_list[[fi]] <- d[.subset_f_post_list[[fi]], , drop = FALSE]
 
@@ -194,10 +188,6 @@ apm_est <- function(fits, post_time, M = 0, R = 1000L, all_models = FALSE, cl =
   for (mi in models_to_keep) {
     model <- models[[mi]]
 
-    d <- mods[[mi]]$data
-
-    time <- post_time
-
    .subset_mi <- .subset_m_post_list[[mi]]
 
    .val_data_mi <- .val_data_m_post_list[[mi]]
@@ -260,14 +250,8 @@ apm_est <- function(fits, post_time, M = 0, R = 1000L, all_models = FALSE, cl =
    mi <- grid[["model"]][fi]
    ti <- grid[["time_ind"]][fi]
 
-    d <- mods[[mi]]$data
-
-    time <- val_times[ti]
-
    .subset_fi <- .subset_f_post_list[[fi]]
 
-    .val_data_fi <- .val_data_f_val_list[[fi]]
-
    .val_weights_fi <- .weights[.subset_fi] * weights[.subset_fi]
 
    .val_groups_fi <- .val_groups_f_val_list[[fi]]
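
The `apm_est()` change is behavior-preserving apart from the branch order: the trivial `all_models = TRUE` case now comes first, the commented-out subsetting leftovers are gone, and unused locals (`d`, `time`, `.val_data_fi`) are dropped. A minimal standalone sketch of the keep/drop logic, with toy stand-ins for the function's internal state (`BMA_weights`, `grid`, and `all_models` below are illustrative values, not package objects):

```r
# Toy stand-ins for the function's state; not objects from apm itself
BMA_weights <- c(0.6, 0, 0.4)
grid <- expand.grid(model = 1:3, time_ind = 1:2)
all_models <- FALSE

if (all_models) {
  models_to_keep <- seq_along(BMA_weights)
  fits_to_keep <- seq_len(nrow(grid))
} else {
  # Zero-weight models contribute nothing to the averaged estimate,
  # so their validation fits can be skipped in the bootstrap
  models_to_keep <- which(BMA_weights > 0)
  fits_to_keep <- which(grid$model %in% models_to_keep)

  BMA_weights <- BMA_weights[models_to_keep]
}

models_to_keep  # 1 3
fits_to_keep    # 1 3 4 6: rows of grid belonging to models 1 and 3
```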

R/apm_pre.R (+5 −8)
@@ -227,7 +227,7 @@ apm_pre <- function(models, data, weights = NULL, group_var, time_var,
   }
 
   #Difference in average prediction errors
-  pred_error_diffs_mat <- pred_errors_array[, , "1"] - pred_errors_array[, , "0"]
+  pred_error_diffs_mat <- pred_errors_array[, , 2L] - pred_errors_array[, , 1L]
 
   #Simulate to get BMA weights
 
@@ -277,15 +277,12 @@ apm_pre <- function(models, data, weights = NULL, group_var, time_var,
         p <- exp(p)
       }
 
-      predicted_val_means_s_i <- setNames(
-        vapply(group_levels, function(g) {
+      predicted_val_means_s_i <- vapply(group_levels, function(g) {
        .wtd_mean(p, val_weights[[f]], val_groups[[f]][[g]])
-        }, numeric(1L)),
-        group_levels
-      )
+      }, numeric(1L))
 
-      mat[ti, mi] <- (observed_val_means[[val_time_c]]["1"] - observed_val_means[[val_time_c]]["0"]) -
-        (predicted_val_means_s_i["1"] - predicted_val_means_s_i["0"])
+      mat[ti, mi] <- (observed_val_means[[val_time_c]][2L] - observed_val_means[[val_time_c]][1L]) -
+        (predicted_val_means_s_i[2L] - predicted_val_means_s_i[1L])
     }
 
     mat
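
The `apm_pre()` change swaps character subscripts (`"0"`, `"1"`) for positional ones (`1L`, `2L`). A small sketch of why that matters, assuming the group levels occupy the first two positions of the indexed margin in sorted order (the array below is a toy, not `pred_errors_array`):

```r
a <- array(1:8, dim = c(2, 2, 2))

a[, , 2L] - a[, , 1L]  # positional indexing works without dimnames
# a[, , "1"]           # would error: no 'dimnames' attribute for array

# Character subscripts only work once the margin is named
dimnames(a) <- list(NULL, NULL, c("0", "1"))
identical(a[, , "1"], a[, , 2L])  # TRUE
```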

R/plot.apm_pre_fits.R (+5 −5)
@@ -14,15 +14,15 @@
 #' A `ggplot` object, which can be manipulated using `ggplot2` syntax (after loading `ggplot2`).
 #'
 #' @details
-#' When `type = "weights"`, `plot()` displays a bar plot with a bar for each model with height equal to the BMA weight/posterior probability of selection for that model. (Note that the plot margins can sometimes cut off the models names; use `theme(plot.margins =)` after loading `ggplot2` to extend the left margin of the plot to ensure all text is visible. Alternatively, the axis text can be rotated using `theme(axis.text.x =)`.)
+#' When `type = "weights"`, `plot()` displays a bar plot with a bar for each model with height equal to the BMA weight/posterior probability of selection for that model. (Note that the plot margins can sometimes cut off the model names; use `theme(plot.margins =)` after loading `ggplot2` to extend the left margin of the plot to ensure all text is visible. Alternatively, the axis text can be rotated using `theme(axis.text.x =)`.)
 #'
-#' When `type = "errors"`, `plot()` displays a lattice of bar plots, with a plot for each model displaying the difference in average prediction errors for each validation period. The period with the largest difference in average prediction errors will be shaded black. The model with the smallest maximum absolute difference in average prediction errors will have a gray label.
+#' When `type = "errors"`, `plot()` displays a lattice of bar plots with a plot for each model displaying the difference in average prediction errors for each validation period. The period with the largest difference in average prediction errors will be shaded black. The model with the smallest maximum absolute difference in average prediction errors will have a gray label.
 #'
-#' When `type = "predict"`, `plot()` displays a lattice of line plots, with a plot for each model displaying the observed and predicted outcomes for each validation period under each model. The observed outcomes are displayed as points, while the predicted outcomes are displayed as lines.
+#' When `type = "predict"`, `plot()` displays a lattice of line plots with a plot for each model displaying the observed and predicted outcomes for each validation period under each model. The observed outcomes are displayed as points, while the predicted outcomes are displayed as lines.
 #'
-#' When `type = "corrected"`, `plot()` displays a lattice of line plots, with a plot for each model displaying the observed and corrected predictions for the treated group for each validation period under each model. The observed outcomes are displayed as points, while the corrected predictions are displayed as lines. Corrected predictions are computed as the observed outcome in the treated group minus the prediction error in the treated group plus the prediction error in the control group.
+#' When `type = "corrected"`, `plot()` displays a lattice of line plots with a plot for each model displaying the observed and corrected predictions for the treated group for each validation period under each model. The observed outcomes are displayed as points, while the corrected predictions are displayed as lines. Corrected predictions are computed as the observed outcome in the treated group minus the prediction error in the treated group plus the prediction error in the control group.
 #'
-#' @seealso [apm_pre()] to to compute the difference in average prediction errors and BMA weights; `ggplot2::geom_col()`, which is used to create the plots.
+#' @seealso [apm_pre()] to to compute the difference in average prediction errors and BMA weights; [ggplot2::geom_col()], which is used to create the plots.
 #'
 #' @examples
 #' data("ptpdata")
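
The `type = "corrected"` sentence in the docs encodes a simple identity. As a one-line sketch with hypothetical stand-in values (none of these names come from the package):

```r
y_treated_obs <- 10.0  # observed outcome, treated group
err_treated   <- 1.5   # prediction error, treated group
err_control   <- 0.5   # prediction error, control group

# corrected prediction = observed - treated error + control error
y_corrected <- y_treated_obs - err_treated + err_control  # 9
```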

R/utils.R (+4 −4)
@@ -2,10 +2,10 @@
 .wtd_mean <- function(x, w = NULL, subset = NULL) {
   if (!is.null(subset)) {
     if (is.null(w)) {
-      return(.wtd_mean(x[subset]))
+      return(Recall(x[subset]))
     }
 
-    return(.wtd_mean(x[subset], w = w[subset]))
+    return(Recall(x[subset], w = w[subset]))
   }
 
   if (is.null(w)) {
@@ -19,10 +19,10 @@
 .wtd_sd <- function(x, w = NULL, subset = NULL) {
   if (!is.null(subset)) {
     if (is.null(w)) {
-      return(.wtd_sd(x[subset]))
+      return(Recall(x[subset]))
     }
 
-    return(.wtd_sd(x[subset], w = w[subset]))
+    return(Recall(x[subset], w = w[subset]))
   }
 
   if (is.null(w)) {
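
`Recall()` refers to the currently executing function, so these recursive helpers keep working even if copied or renamed. A self-contained sketch of the pattern (the tail of the function is assumed, since the diff cuts off at `if (is.null(w)) {`):

```r
wtd_mean_sketch <- function(x, w = NULL, subset = NULL) {
  if (!is.null(subset)) {
    if (is.null(w)) {
      return(Recall(x[subset]))
    }

    return(Recall(x[subset], w = w[subset]))
  }

  if (is.null(w)) {
    return(mean(x))
  }

  sum(x * w) / sum(w)  # assumed tail; the diff truncates here
}

# Recursion survives renaming because Recall() does not use the name
renamed <- wtd_mean_sketch
rm(wtd_mean_sketch)
renamed(c(1, 2, 3, 4), w = c(1, 1, 1, 3), subset = c(2, 4))  # 3.5
```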

README.Rmd (+12 −51)
@@ -18,65 +18,26 @@ knitr::opts_chunk$set(
 <!-- badges: start -->
 <!-- badges: end -->
 
-```{r setup}
-library(apm)
-data("ptpdata")
-```
-
-### Supplying models
-
-We can specify the models to test using `apm_mod()`. This create a full cross of all supplied arguments, which include model formula, families, whether the outcome is logged or not, whether fixed effects are included or not, whether the outcome should be a difference, and whether outcome lags should appear as predictors. Below, we create a cross of 9 models.
-
-```{r}
-models <- apm_mod(deaths ~ 1,
-                  family = list("gaussian", "quasipoisson"),
-                  log = c(TRUE, FALSE),
-                  lag = 0, diff_k = 0,
-                  time_trend = 0:2)
-
-models
-```
+## Introduction
 
-Normally, this cross would yield 12 = 3 (formulas) x 2 (families) x 2 (log T/F), but by default any models with non-linear links and `log = TRUE` are removed, leaving 9 models. If we want to manually add other models, we can so by creating a new models object and appending it to the current one.
+The `apm` package implements *Averaged Prediction Models (APM)*, a Bayesian model averaging approach for controlled pre-post designs. These designs compare differences over time between a group that becomes exposed (treated group) and one that remains unexposed (comparison group). With appropriate causal assumptions, they can identify the causal effect of the exposure/treatment.
 
-```{r}
-models2 <- apm_mod(list(deaths ~ 1),
-                   diff_k = 1)
+In APM, we specify a collection of models that predict untreated outcomes. Our causal identifying assumption is that the model's prediction errors would be equal (in expectation) in the treated and comparison groups in the absence of the exposure. This is a generalization of familiar methods like Difference-in-Differences (DiD) and Comparative Interrupted Time Series (CITS).
 
-models <- c(models, models2)
+Because many models may be plausible for this prediction task, we combine them using Bayesian model averaging. We weight each model by its robustness to violations of the causal assumption.
 
-models
-```
-
-This leaves us with 10 models.
+## Installation
 
-### Fitting the models
+To install the development version from GitHub, use:
 
-Next we fit all 10 models to the data. We do so once for each validation time to compute the average prediction error that will be used to select the optimal model. All models are fit simultaneously so the simulation can use the full joint distribution of model parameter estimates. For each validation time, each model is fit using a dataset that contains data points prior to that time.
+```{r eval = FALSE}
 
-We use `apm_fit()` to fit the models, and calculate the prediction errors and BMA weights.
+# Install devtools if not already installed
+install.packages("remotes")
 
-```{r}
-fits <- apm_pre(models,
-                data = ptpdata,
-                group_var = "group",
-                time_var = "year",
-                unit_var = "state",
-                val_times = 2004:2007)
+# Install apm package from GitHub if not already installed
+remotes::install_github("tl2624/apm")
 
-fits
 ```
 
-### Computing the ATT
-
-We compute the ATT using `apm_est()`, which uses bootstrapping to compute model uncertainty due to sampling along with uncertainty due to model selection.
-
-```{r}
-est <- apm_est(fits,
-               post_time = 2008,
-               M = 1)
-
-est
-
-summary(est)
-```
+See `vignette("apm")` for details on using the package.
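
For readers landing on this commit: the quick-start that the README previously carried (now deferred to `vignette("apm")`) can be reassembled from the removed chunks above, roughly as follows:

```r
library(apm)
data("ptpdata")

# Candidate models: a cross of families, log transform, and time trends,
# plus an appended first-difference model (10 models total)
models <- c(apm_mod(deaths ~ 1,
                    family = list("gaussian", "quasipoisson"),
                    log = c(TRUE, FALSE),
                    lag = 0, diff_k = 0,
                    time_trend = 0:2),
            apm_mod(list(deaths ~ 1), diff_k = 1))

# Fit each model at each validation time to get prediction errors
# and BMA weights
fits <- apm_pre(models,
                data = ptpdata,
                group_var = "group",
                time_var = "year",
                unit_var = "state",
                val_times = 2004:2007)

# Estimate the ATT, with bootstrap-based sampling and model uncertainty
est <- apm_est(fits, post_time = 2008, M = 1)
summary(est)
```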
