Skip to content

Commit ad25c68

Browse files
committed
Updates to add figure 3 plotting
1 parent b260d2b commit ad25c68

File tree

7 files changed

+256
-54
lines changed

7 files changed

+256
-54
lines changed

DESCRIPTION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: apm
22
Title: What the Package Does (One Line, Title Case)
3-
Version: 0.0.0.9001
3+
Version: 0.0.0.9002
44
Authors@R:
55
c(person("Thomas", "Leavitt", email = "[email protected]", role = c("aut"),
66
comment = c(ORCID = "0000-0002-3668-6409")),

R/apm_est.R

+2
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,9 @@ apm_est <- function(fits, post_time, M = 0, R = 1000L, all_models = FALSE, cl =
234234
BMA_var_m = c(ATT = unname(BMA_var_m)),
235235
M = M,
236236
post_time = post_time,
237+
observed_means = fits$observed_means,
237238
pred_errors = fits$pred_errors,
239+
pred_errors_diff = fits$pred_errors_diff,
238240
BMA_weights = BMA_weights,
239241
boot_out = boot_out)
240242

R/apm_pre.R

+46-37
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@
5858

5959
#' @export
6060
apm_pre <- function(models, data, weights = NULL, group_var, time_var,
61-
val_times, unit_var, nsim = 1000, cl = NULL,
62-
verbose = TRUE) {
61+
val_times, unit_var, nsim = 1000, cl = NULL,
62+
verbose = TRUE) {
6363

6464
# Argument checks
6565
chk::chk_not_missing(models, "`models`")
@@ -136,13 +136,25 @@ apm_pre <- function(models, data, weights = NULL, group_var, time_var,
136136
}
137137

138138
#Fit all estimates
139-
val_data <- val_weights <- val_fits <- val_coefs <- observed_val_means <- vector("list", nrow(grid))
140-
141-
apm_mat <- mat0 <- matrix(NA_real_,
142-
nrow = length(val_times),
143-
ncol = length(models),
144-
dimnames = list(val_times,
145-
names(models)))
139+
val_data <- val_weights <- val_fits <- val_coefs <- vector("list", nrow(grid))
140+
141+
#Get observed means at each time point
142+
times <- sort(unique(data[[time_var]]))
143+
times <- times[times <= max(val_times)]
144+
y <- model.response(model.frame(models[[1]]$formula, data = data))
145+
146+
observed_val_means <- setNames(lapply(times, function(t) {
147+
setNames(
148+
vapply(group_levels, function(g) {
149+
mean(y[data[[time_var]] == t & data[[group_var]] == g])
150+
}, numeric(1L)),
151+
group_levels
152+
)
153+
}), times)
154+
155+
apm_arr <- array(NA_real_,
156+
dim = c(length(val_times), length(models), 2L),
157+
dimnames = list(val_times, names(models), group_levels))
146158

147159
if (verbose) {
148160
cat("Fitting models...")
@@ -171,21 +183,7 @@ apm_pre <- function(models, data, weights = NULL, group_var, time_var,
171183
val_data[[f]] <- d[subset_i,, drop = FALSE]
172184
val_weights[[f]] <- weights[subset_i]
173185
val_coefs[[f]] <- na.omit(marginaleffects::get_coef(fit))
174-
175-
y <- model.response(model.frame(update(mod$formula, . ~ 1),
176-
data = val_data[[f]]))
177-
178-
if (model$log) {
179-
y <- exp(y)
180-
}
181-
182-
observed_val_means[[f]] <- setNames(
183-
vapply(group_levels, function(g) {
184-
.wtd_mean(y, val_weights[[f]], val_data[[f]][[group_var]] == g)
185-
}, numeric(1L)),
186-
group_levels
187-
)
188-
186+
189187
#Compute pred error
190188

191189
# Compute prediction errors for each model for each validation period using original coefs
@@ -205,23 +203,26 @@ apm_pre <- function(models, data, weights = NULL, group_var, time_var,
205203
group_levels
206204
)
207205

208-
pred_error <- (observed_val_means[[f]]["1"] - observed_val_means[[f]]["0"]) -
209-
(predicted_val_means_i["1"] - predicted_val_means_i["0"])
210-
211-
apm_mat[t, i] <- pred_error
206+
for (g in group_levels) {
207+
apm_arr[t, i, g] <- observed_val_means[[as.character(val_time)]][g] - predicted_val_means_i[g]
208+
}
212209

213210
val_fits[[f]] <- fit
214211

215212
grid[["time_ind"]][f] <- t
216213
grid[["model"]][f] <- i
217-
f <- f + 1
214+
f <- f + 1L
218215
}
219216
}
220217

218+
#Difference in average prediction errors
219+
apm_mat <- apm_arr[,, "1"] - apm_arr[,, "0"]
220+
221221
#Simulate to get BMA weights
222222

223223
## Joint variance of all model coefficients, clustering for unit
224224
val_vcov <- vcovSUEST(val_fits, cluster = data[[unit_var]])
225+
225226
if (verbose) {
226227
cat(" Done.\nSimulating to compute BMA weights...\n")
227228
}
@@ -238,14 +239,19 @@ apm_pre <- function(models, data, weights = NULL, group_var, time_var,
238239
#out_mat: all prediction errors (differences between groups); length(val_times) x length(models) x nsim
239240
out_mat <- simplify2array(pbapply::pblapply(seq_len(nsim), function(s) {
240241

241-
mat <- mat0
242+
mat <- matrix(NA_real_,
243+
nrow = length(val_times),
244+
ncol = length(models),
245+
dimnames = list(val_times, names(models)))
242246

243247
coefs <- sim_coefs[s,]
244248

245249
for (f in seq_len(nrow(grid))) {
246250
i <- grid$model[f]
247251
t <- grid$time_ind[f]
248252

253+
val_time <- val_times[t]
254+
249255
fit <- val_fits[[f]]
250256

251257
#Compute pred error
@@ -267,11 +273,9 @@ apm_pre <- function(models, data, weights = NULL, group_var, time_var,
267273
}, numeric(1L)),
268274
group_levels
269275
)
270-
271-
pred_error <- (observed_val_means[[f]]["1"] - predicted_val_means_s_i["1"]) -
272-
(observed_val_means[[f]]["0"] - predicted_val_means_s_i["0"])
273-
274-
mat[t, i] <- pred_error
276+
277+
mat[t, i] <- (observed_val_means[[as.character(val_time)]]["1"] - observed_val_means[[as.character(val_time)]]["0"]) -
278+
(predicted_val_means_s_i["1"] - predicted_val_means_s_i["0"])
275279
}
276280

277281
mat
@@ -290,6 +294,9 @@ apm_pre <- function(models, data, weights = NULL, group_var, time_var,
290294
cat("Done.\n")
291295
}
292296

297+
observed_means <- do.call("rbind", observed_val_means)
298+
rownames(observed_means) <- names(observed_val_means)
299+
293300
BMA_weights <- tabulate(optimal_models, nbins = length(models)) / nsim
294301

295302
fits <- list(models = models,
@@ -300,7 +307,9 @@ apm_pre <- function(models, data, weights = NULL, group_var, time_var,
300307
val_vcov = val_vcov,
301308
data = data,
302309
weights = weights,
303-
pred_errors = apm_mat,
310+
observed_means = observed_means,
311+
pred_errors = apm_arr,
312+
pred_errors_diff = apm_mat,
304313
BMA_weights = BMA_weights,
305314
nsim = nsim)
306315

@@ -331,7 +340,7 @@ print.apm_pre_fits <- function(x, ...) {
331340
#' @exportS3Method summary apm_pre_fits
332341
summary.apm_pre_fits <- function(object, order = NULL, ...) {
333342
out <- data.frame(bma = object$BMA_weights,
334-
err = apply(abs(object[["pred_errors"]]), 2, max),
343+
err = apply(abs(object[["pred_errors_diff"]]), 2, max),
335344
row.names = names(object$models))
336345

337346
names(out) <- c("BMA weights", "Max|errors|")

0 commit comments

Comments
 (0)