Skip to content

Commit

Permalink
move prior weights to model arguments (#450)
Browse files Browse the repository at this point in the history
* move prior weights to model arguments

* fix typo in test

* Add PR and @seabbs to news so I can steal credit

* redoc

* rename and expand doc

---------

Co-authored-by: Sam Abbott <[email protected]>
  • Loading branch information
sbfnk and seabbs authored Sep 8, 2023
1 parent d66f736 commit 0028686
Show file tree
Hide file tree
Showing 15 changed files with 77 additions and 48 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Type: Package
Package: EpiNow2
Title: Estimate Real-Time Case Counts and Time-Varying
Epidemiological Parameters
Version: 1.3.6.9007
Version: 1.3.6.9008
Authors@R:
c(person(given = "Sam",
family = "Abbott",
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ This release is in development. For a stable release install 1.3.5 from CRAN.
* Added content to the vignette for the estimate_truncation model. By @sbfnk in #439 and reviewed by @seabbs.
* Added a feature to the `estimate_truncation` to allow it to be applied to time series that are shorter than the truncation max. By @sbfnk in #438 and reviewed by @seabbs.
* Changed the `estimate_truncation` to use the `dist_spec` interface, deprecating existing options `max_trunc` and `trunc_dist`. By @sbfnk in #448 reviewed by @seabbs.
* Added a `weigh_delay_priors` argument to the main functions, allowing the users to choose whether to weigh delay priors by the number of data points or not. By @sbfnk in #450 and reviewed by @seabbs.

## Documentation

Expand Down
9 changes: 3 additions & 6 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -640,12 +640,11 @@ create_stan_args <- function(stan = stan_opts(),
##'
##' @param ... Named delay distributions specified using `dist_spec()`.
##' The names are assigned to IDs
##' @param ot Integer, number of observations (needed if weighing any priors)
##' with the number of observations
##' @param weight Numeric, weight associated with delay priors; default: 1
##' @return A list of variables as expected by the stan model
##' @importFrom purrr transpose map
##' @author Sebastian Funk
create_stan_delays <- function(..., ot) {
create_stan_delays <- function(..., weight = 1) {
dot_args <- list(...)
## combine delays
combined_delays <- unclass(c(...))
Expand Down Expand Up @@ -673,9 +672,7 @@ create_stan_delays <- function(..., ot) {
## map pmfs
ret$np_pmf_groups <- array(c(0, cumsum(combined_delays$np_pmf_length)) + 1)
## assign prior weights
if (any(ret$weight == 0)) {
ret$weight[ret$weight == 0] <- ot
}
ret$weight <- array(rep(weight, ret$n_p))
## remove auxiliary variables
ret$fixed <- NULL
ret$np_pmf_length <- NULL
Expand Down
10 changes: 1 addition & 9 deletions R/dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -889,11 +889,6 @@ tune_inv_gamma <- function(lower = 2, upper = 21) {
#' as coming from fixed (vs uncertain) distributions. Overrides any values
#' assigned to \code{mean_sd} and \code{sd_sd} by setting them to zero.
#' reduces compute requirement but may produce spuriously precise estimates.
#' @param prior_weight Integer, weight given to the generation time prior.
#' By default (prior_weight = 0) the priors will be weighted by the number of
#' observation data points, usually preventing the posteriors from shifting
#' much from the given distribution. Another sensible option would be 1,
#' i.e. treating the generation time distribution as a single parameter.
#' @return A list of distribution options.
#'
#' @author Sebastian Funk
Expand All @@ -910,7 +905,7 @@ tune_inv_gamma <- function(lower = 2, upper = 21) {
#' )
dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
distribution = c("lognormal", "gamma"), max,
pmf = numeric(0), fixed = FALSE, prior_weight = 0L) {
pmf = numeric(0), fixed = FALSE) {
## check if parametric or nonparametric
if (length(pmf) > 0 &&
!all(
Expand Down Expand Up @@ -978,7 +973,6 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
np_pmf_max = 0,
np_pmf = numeric(0),
np_pmf_length = integer(0),
weight = numeric(0),
fixed = integer(0)
))
} else { ## parametric fixed
Expand Down Expand Up @@ -1017,7 +1011,6 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
np_pmf_max = length(pmf),
np_pmf = pmf,
np_pmf_length = length(pmf),
weight = numeric(0),
fixed = 1L
))
}
Expand All @@ -1036,7 +1029,6 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
np_pmf_max = 0,
np_pmf = numeric(0),
np_pmf_length = integer(0),
weight = prior_weight,
fixed = 0L
)
}
Expand Down
12 changes: 11 additions & 1 deletion R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@
#' @param horizon Numeric, defaults to 7. Number of days into the future to
#' forecast.
#'
#' @param weigh_delay_priors Logical. If TRUE (default), all delay distribution
#' priors will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE, no weight will be applied,
#' i.e. delay distributions will be treated as a single parameters.
#'
#' @param verbose Logical, defaults to `TRUE` when used interactively and
#' otherwise `FALSE`. Should verbose debug progress messages be printed.
#' Corresponds to the "DEBUG" level from `futile.logger`. See `setup_logging`
Expand All @@ -48,6 +54,7 @@
#' @seealso epinow regional_epinow forecast_infections simulate_infections
#' @inheritParams create_stan_args
#' @inheritParams create_stan_data
#' @inheritParams create_stan_data
#' @inheritParams create_gp_data
#' @inheritParams fit_model_with_nuts
#' @inheritParams create_clean_reported_cases
Expand Down Expand Up @@ -245,6 +252,7 @@ estimate_infections <- function(reported_cases,
CrIs = c(0.2, 0.5, 0.9),
filter_leading_zeros = TRUE,
zero_threshold = Inf,
weigh_delay_priors = TRUE,
id = "estimate_infections",
verbose = interactive()) {
set_dt_single_thread()
Expand Down Expand Up @@ -311,7 +319,9 @@ estimate_infections <- function(reported_cases,
gt = generation_time,
delay = delays,
trunc = truncation,
ot = data$t - data$seeding_time - data$horizon
weight = ifelse(
weigh_delay_priors, data$t - data$seeding_time - data$horizon, 1
)
))

# Set up default settings
Expand Down
13 changes: 10 additions & 3 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@
#' use for estimation but not to fit to at the beginning of the time series.
#' This must be less than the number of observations.
#'
#' @param weigh_delay_priors Logical. If TRUE, all delay distribution priors
#' will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE (default), no weight will
#' be applied, i.e. delay distributions will be treated as a single
#' parameters.
#'
#' @param verbose Logical, should model fitting progress be returned. Defaults
#' to `interactive()`.
#'
Expand Down Expand Up @@ -129,8 +136,7 @@ estimate_secondary <- function(reports,
delays = delay_opts(
dist_spec(
mean = 2.5, mean_sd = 0.5,
sd = 0.47, sd_sd = 0.25, max = 30,
prior_weight = 1
sd = 0.47, sd_sd = 0.25, max = 30
)
),
truncation = trunc_opts(),
Expand All @@ -139,6 +145,7 @@ estimate_secondary <- function(reports,
CrIs = c(0.2, 0.5, 0.9),
priors = NULL,
model = NULL,
weigh_delay_priors = FALSE,
verbose = interactive(),
...) {
reports <- data.table::as.data.table(reports)
Expand All @@ -161,7 +168,7 @@ estimate_secondary <- function(reports,
data <- c(data, create_stan_delays(
delay = delays,
trunc = truncation,
ot = data$t
weight = ifelse(weigh_delay_priors, data$t, 1)
))

# observation model data
Expand Down
19 changes: 14 additions & 5 deletions R/estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@
#' @param model A compiled stan model to override the default model. May be
#' useful for package developers or those developing extensions.
#'
#' @param weigh_delay_priors Logical. If TRUE, all delay distribution priors
#' will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE (default), no weight will
#' be applied, i.e. delay distributions will be treated as a single
#' parameters.
#'
#' @param verbose Logical, should model fitting progress be returned.
#'
#' @param ... Additional parameters to pass to `rstan::sampling`.
Expand Down Expand Up @@ -134,10 +141,11 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,
trunc_dist = "lognormal",
truncation = dist_spec(
mean = 0, sd = 0, mean_sd = 1, sd_sd = 1,
max = 10, prior_weight = 1L
max = 10
),
model = NULL,
CrIs = c(0.2, 0.5, 0.9),
weigh_delay_priors = FALSE,
verbose = TRUE,
...) {

Expand Down Expand Up @@ -189,7 +197,7 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,
if (construct_trunc) {
truncation <- dist_spec(
mean = 0, mean_sd = 1, sd = 0, sd_sd = 1, distribution = trunc_dist,
max = trunc_max, prior_weight = 1
max = trunc_max
)
}

Expand All @@ -216,9 +224,10 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,
obs_sets = ncol(obs_data)
)

data <- c(data,
create_stan_delays(trunc = truncation)
)
data <- c(data, create_stan_delays(
trunc = truncation,
weight = ifelse(weigh_delay_priors, data$t, 1)
))

## convert to integer
data$trunc_dist <-
Expand Down
5 changes: 2 additions & 3 deletions man/create_stan_delays.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 1 addition & 8 deletions man/dist_spec.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions man/estimate_infections.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion man/estimate_secondary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 9 additions & 2 deletions man/estimate_truncation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-delays.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ test_stan_delays <- function(generation_time = generation_time_opts(),
generation_time = generation_time,
delays = delays,
truncation = truncation,
ot = 10
weight = 10
)
return(unlist(unname(data[params])))
}
Expand Down
12 changes: 6 additions & 6 deletions tests/testthat/test-estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ test_that("forecast_secondary can return values from simulated data and plot
expect_error(plot(inc_preds, new_obs = cases, from = "2020-05-01"), NA)
})

test_that("estimate_secondary works with prior_weight = 0", {
weight_0_delays <- dist_spec(
test_that("estimate_secondary works with weigh_delay_priors = TRUE", {
delays <- dist_spec(
mean = 2.5, mean_sd = 0.5, sd = 0.47, sd_sd = 0.25, max = 30
)
inc_weight_0 <- estimate_secondary(
cases[1:60], delays = weight_0_delays,
inc_weigh <- estimate_secondary(
cases[1:60], delays = delays,
obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE),
verbose = FALSE
weigh_delay_priors = TRUE, verbose = FALSE
)
expect_s3_class(inc_weight_0, "estimate_secondary")
expect_s3_class(inc_weigh, "estimate_secondary")
})
3 changes: 1 addition & 2 deletions tests/testthat/test-estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ trunc_dist <- dist_spec(
mean_sd = 0.1,
sd = convert_to_logsd(3, 2),
sd_sd = 0.1,
max = 10,
prior_weight = 1
max = 10
)

# apply truncation to example data
Expand Down

0 comments on commit 0028686

Please sign in to comment.