Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

move prior weights to model arguments #450

Merged
merged 5 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ Type: Package
Package: EpiNow2
Title: Estimate Real-Time Case Counts and Time-Varying
Epidemiological Parameters
Version: 1.3.6.9007
Version: 1.3.6.9008
sbfnk marked this conversation as resolved.
Show resolved Hide resolved
Authors@R:
c(person(given = "Sam",
family = "Abbott",
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ This release is in development. For a stable release install 1.3.5 from CRAN.
* Added content to the vignette for the estimate_truncation model. By @sbfnk in #439 and reviewed by @seabbs.
* Added a feature to the `estimate_truncation` to allow it to be applied to time series that are shorter than the truncation max. By @sbfnk in #438 and reviewed by @seabbs.
* Changed the `estimate_truncation` to use the `dist_spec` interface, deprecating existing options `max_trunc` and `trunc_dist`. By @sbfnk in #448 reviewed by @seabbs.
* Added a `weigh_delay_priors` argument to the main functions, allowing the users to choose whether to weigh delay priors by the number of data points or not. By @sbfnk in #450 and reviewed by @seabbs.

## Documentation

Expand Down
9 changes: 3 additions & 6 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -640,12 +640,11 @@ create_stan_args <- function(stan = stan_opts(),
##'
##' @param ... Named delay distributions specified using `dist_spec()`.
##' The names are assigned to IDs
##' @param ot Integer, number of observations (needed if weighing any priors)
##' with the number of observations
##' @param weight Numeric, weight associated with delay priors; default: 1
##' @return A list of variables as expected by the stan model
##' @importFrom purrr transpose map
##' @author Sebastian Funk
create_stan_delays <- function(..., ot) {
create_stan_delays <- function(..., weight = 1) {
dot_args <- list(...)
## combine delays
combined_delays <- unclass(c(...))
Expand Down Expand Up @@ -673,9 +672,7 @@ create_stan_delays <- function(..., ot) {
## map pmfs
ret$np_pmf_groups <- array(c(0, cumsum(combined_delays$np_pmf_length)) + 1)
## assign prior weights
if (any(ret$weight == 0)) {
ret$weight[ret$weight == 0] <- ot
}
ret$weight <- array(rep(weight, ret$n_p))
sbfnk marked this conversation as resolved.
Show resolved Hide resolved
## remove auxiliary variables
ret$fixed <- NULL
ret$np_pmf_length <- NULL
Expand Down
10 changes: 1 addition & 9 deletions R/dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -889,11 +889,6 @@ tune_inv_gamma <- function(lower = 2, upper = 21) {
#' as coming from fixed (vs uncertain) distributions. Overrides any values
#' assigned to \code{mean_sd} and \code{sd_sd} by setting them to zero.
#' reduces compute requirement but may produce spuriously precise estimates.
#' @param prior_weight Integer, weight given to the generation time prior.
#' By default (prior_weight = 0) the priors will be weighted by the number of
#' observation data points, usually preventing the posteriors from shifting
#' much from the given distribution. Another sensible option would be 1,
#' i.e. treating the generation time distribution as a single parameter.
#' @return A list of distribution options.
#'
#' @author Sebastian Funk
Expand All @@ -910,7 +905,7 @@ tune_inv_gamma <- function(lower = 2, upper = 21) {
#' )
dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
distribution = c("lognormal", "gamma"), max,
pmf = numeric(0), fixed = FALSE, prior_weight = 0L) {
pmf = numeric(0), fixed = FALSE) {
## check if parametric or nonparametric
if (length(pmf) > 0 &&
!all(
Expand Down Expand Up @@ -978,7 +973,6 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
np_pmf_max = 0,
np_pmf = numeric(0),
np_pmf_length = integer(0),
weight = numeric(0),
fixed = integer(0)
))
} else { ## parametric fixed
Expand Down Expand Up @@ -1017,7 +1011,6 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
np_pmf_max = length(pmf),
np_pmf = pmf,
np_pmf_length = length(pmf),
weight = numeric(0),
fixed = 1L
))
}
Expand All @@ -1036,7 +1029,6 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
np_pmf_max = 0,
np_pmf = numeric(0),
np_pmf_length = integer(0),
weight = prior_weight,
fixed = 0L
)
}
Expand Down
12 changes: 11 additions & 1 deletion R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@
#' @param horizon Numeric, defaults to 7. Number of days into the future to
#' forecast.
#'
#' @param weigh_delay_priors Logical. If TRUE (default), all delay distribution
#' priors will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE, no weight will be applied,
sbfnk marked this conversation as resolved.
Show resolved Hide resolved
#' i.e. delay distributions will be treated as a single parameters.
#'
#' @param verbose Logical, defaults to `TRUE` when used interactively and
#' otherwise `FALSE`. Should verbose debug progress messages be printed.
#' Corresponds to the "DEBUG" level from `futile.logger`. See `setup_logging`
Expand All @@ -48,6 +54,7 @@
#' @seealso epinow regional_epinow forecast_infections simulate_infections
#' @inheritParams create_stan_args
#' @inheritParams create_stan_data
#' @inheritParams create_stan_data
#' @inheritParams create_gp_data
#' @inheritParams fit_model_with_nuts
#' @inheritParams create_clean_reported_cases
Expand Down Expand Up @@ -245,6 +252,7 @@ estimate_infections <- function(reported_cases,
CrIs = c(0.2, 0.5, 0.9),
filter_leading_zeros = TRUE,
zero_threshold = Inf,
weigh_delay_priors = TRUE,
id = "estimate_infections",
verbose = interactive()) {
set_dt_single_thread()
Expand Down Expand Up @@ -311,7 +319,9 @@ estimate_infections <- function(reported_cases,
gt = generation_time,
delay = delays,
trunc = truncation,
ot = data$t - data$seeding_time - data$horizon
weight = ifelse(
weigh_delay_priors, data$t - data$seeding_time - data$horizon, 1
)
))

# Set up default settings
Expand Down
13 changes: 10 additions & 3 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@
#' use for estimation but not to fit to at the beginning of the time series.
#' This must be less than the number of observations.
#'
#' @param weigh_delay_priors Logical. If TRUE, all delay distribution priors
#' will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE (default), no weight will
#' be applied, i.e. delay distributions will be treated as a single
#' parameters.
#'
#' @param verbose Logical, should model fitting progress be returned. Defaults
#' to `interactive()`.
#'
Expand Down Expand Up @@ -129,8 +136,7 @@ estimate_secondary <- function(reports,
delays = delay_opts(
dist_spec(
mean = 2.5, mean_sd = 0.5,
sd = 0.47, sd_sd = 0.25, max = 30,
prior_weight = 1
sd = 0.47, sd_sd = 0.25, max = 30
)
),
truncation = trunc_opts(),
Expand All @@ -139,6 +145,7 @@ estimate_secondary <- function(reports,
CrIs = c(0.2, 0.5, 0.9),
priors = NULL,
model = NULL,
weigh_delay_priors = FALSE,
verbose = interactive(),
...) {
reports <- data.table::as.data.table(reports)
Expand All @@ -161,7 +168,7 @@ estimate_secondary <- function(reports,
data <- c(data, create_stan_delays(
delay = delays,
trunc = truncation,
ot = data$t
weight = ifelse(weigh_delay_priors, data$t, 1)
))

# observation model data
Expand Down
19 changes: 14 additions & 5 deletions R/estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@
#' @param model A compiled stan model to override the default model. May be
#' useful for package developers or those developing extensions.
#'
#' @param weigh_delay_priors Logical. If TRUE, all delay distribution priors
#' will be weighted by the number of observation data points, in doing so
#' approximately placing an independent prior at each time step and usually
#' preventing the posteriors from shifting. If FALSE (default), no weight will
#' be applied, i.e. delay distributions will be treated as a single
#' parameters.
#'
#' @param verbose Logical, should model fitting progress be returned.
#'
#' @param ... Additional parameters to pass to `rstan::sampling`.
Expand Down Expand Up @@ -134,10 +141,11 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,
trunc_dist = "lognormal",
truncation = dist_spec(
mean = 0, sd = 0, mean_sd = 1, sd_sd = 1,
max = 10, prior_weight = 1L
max = 10
),
model = NULL,
CrIs = c(0.2, 0.5, 0.9),
weigh_delay_priors = FALSE,
verbose = TRUE,
...) {

Expand Down Expand Up @@ -189,7 +197,7 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,
if (construct_trunc) {
truncation <- dist_spec(
mean = 0, mean_sd = 1, sd = 0, sd_sd = 1, distribution = trunc_dist,
max = trunc_max, prior_weight = 1
max = trunc_max
)
}

Expand All @@ -216,9 +224,10 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10,
obs_sets = ncol(obs_data)
)

data <- c(data,
create_stan_delays(trunc = truncation)
)
data <- c(data, create_stan_delays(
trunc = truncation,
weight = ifelse(weigh_delay_priors, data$t, 1)
))

## convert to integer
data$trunc_dist <-
Expand Down
5 changes: 2 additions & 3 deletions man/create_stan_delays.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 1 addition & 8 deletions man/dist_spec.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions man/estimate_infections.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion man/estimate_secondary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 9 additions & 2 deletions man/estimate_truncation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-delays.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ test_stan_delays <- function(generation_time = generation_time_opts(),
generation_time = generation_time,
delays = delays,
truncation = truncation,
ot = 10
weight = 10
)
return(unlist(unname(data[params])))
}
Expand Down
12 changes: 6 additions & 6 deletions tests/testthat/test-estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ test_that("forecast_secondary can return values from simulated data and plot
expect_error(plot(inc_preds, new_obs = cases, from = "2020-05-01"), NA)
})

test_that("estimate_secondary works with prior_weight = 0", {
weight_0_delays <- dist_spec(
test_that("estimate_secondary works with weigh_delay_priors = TRUE", {
delays <- dist_spec(
mean = 2.5, mean_sd = 0.5, sd = 0.47, sd_sd = 0.25, max = 30
)
inc_weight_0 <- estimate_secondary(
cases[1:60], delays = weight_0_delays,
inc_weigh <- estimate_secondary(
cases[1:60], delays = delays,
obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE),
verbose = FALSE
weigh_delay_priors = TRUE, verbose = FALSE
)
expect_s3_class(inc_weight_0, "estimate_secondary")
expect_s3_class(inc_weigh, "estimate_secondary")
})
3 changes: 1 addition & 2 deletions tests/testthat/test-estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ trunc_dist <- dist_spec(
mean_sd = 0.1,
sd = convert_to_logsd(3, 2),
sd_sd = 0.1,
max = 10,
prior_weight = 1
max = 10
)

# apply truncation to example data
Expand Down