Skip to content

Commit

Permalink
Default choices (#622)
Browse files Browse the repository at this point in the history
* move choices to default arguments

* add news item

* don't allow multiples

* update order

* remove stray choices

* update order

* add reviewer

Co-authored-by: Sam Abbott <[email protected]>

---------

Co-authored-by: Sam Abbott <[email protected]>
  • Loading branch information
sbfnk and seabbs authored Mar 22, 2024
1 parent 5a5d009 commit a2b741a
Show file tree
Hide file tree
Showing 27 changed files with 117 additions and 97 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
* Tests have been updated to only set random seeds before snapshot tests involving random number generation, and unset them subsequently. By @sbfnk in #590 and reviewed by @seabbs.
* A function `simulate_secondary()` was added to simulate from parameters of the `estimate_secondary` model. A function of the same name that was previously based on a reimplementation of that model in R with potentially time-varying scalings and delays has been renamed to `convolve_and_scale()`. By @sbfnk in #591 and reviewed by @seabbs.
* Fixed broken links in the README. By @jamesmbaazam in #617 and reviewed by @sbfnk.
* Argument choices have been moved into default arguments. By @sbfnk in #622 and reviewed by @seabbs.

## Model changes

Expand Down
16 changes: 7 additions & 9 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,16 @@
#' @importFrom rlang arg_match
#' @return Called for its side effects.
#' @keywords internal
check_reports_valid <- function(reports, model) {
check_reports_valid <- function(reports,
model = c(
"estimate_infections",
"estimate_truncation",
"estimate_secondary"
)) {
# Check that the case time series (reports) is a data frame
assert_data_frame(reports)
# Perform checks depending on the model to the data is meant to be used with
model <- arg_match(
model,
values = c(
"estimate_infections",
"estimate_truncation",
"estimate_secondary"
)
)
model <- arg_match(model)

if (model == "estimate_secondary") {
# Check that reports has the right column names
Expand Down
12 changes: 3 additions & 9 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,17 +181,11 @@ create_shifted_cases <- function(reported_cases, shift,
#' @param delay Numeric mean delay
#' @importFrom rlang arg_match
#' @return A list containing a logical called fixed and an integer called from
create_future_rt <- function(future = "latest", delay = 0) {
create_future_rt <- function(future = c("latest", "project", "estimate"),
delay = 0) {
out <- list(fixed = FALSE, from = 0)
if (is.character(future)) {
future <- arg_match(
future,
values = c(
"project",
"latest",
"estimate"
)
)
future <- arg_match(future)
if (!(future == "project")) {
out$fixed <- TRUE
out$from <- ifelse(future == "latest", 0, -delay)
Expand Down
7 changes: 4 additions & 3 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -439,10 +439,11 @@ plot.estimate_secondary <- function(x, primary = FALSE,
#' # Simulate secondary cases
#' cases <- convolve_and_scale(cases, type = "prevalence")
#' cases
convolve_and_scale <- function(data, type = "incidence", family = "poisson",
convolve_and_scale <- function(data, type = c("incidence", "prevalence"),
family = c("none", "poisson", "negbin"),
delay_max = 30, ...) {
type <- arg_match(type, values = c("incidence", "prevalence"))
family <- arg_match(family, values = c("none", "poisson", "negbin"))
type <- arg_match(type)
family <- arg_match(family)
data <- data.table::as.data.table(data)
data <- data.table::copy(data)
data <- data[, index := seq_len(.N)]
Expand Down
40 changes: 19 additions & 21 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,8 @@ generation_time_opts <- function(dist = Fixed(1), ...,
#'
#' # prevalence model
#' secondary_opts("prevalence")
secondary_opts <- function(type = "incidence", ...) {
type <- arg_match(
type,
values = c("incidence", "prevalence")
)
secondary_opts <- function(type = c("incidence", "prevalence"), ...) {
type <- arg_match(type)
if (type == "incidence") {
data <- list(
cumulative = 0,
Expand Down Expand Up @@ -349,7 +346,7 @@ rt_opts <- function(prior = list(mean = 1, sd = 1),
rw = 0,
use_breakpoints = TRUE,
future = "latest",
gp_on = "R_t-1",
gp_on = c("R_t-1", "R0"),
pop = 0) {
rt <- list(
prior = prior,
Expand All @@ -358,7 +355,7 @@ rt_opts <- function(prior = list(mean = 1, sd = 1),
use_breakpoints = use_breakpoints,
future = future,
pop = pop,
gp_on = arg_match(gp_on, values = c("R_t-1", "R0"))
gp_on = arg_match(gp_on)
)

# replace default settings with those specified by user
Expand Down Expand Up @@ -407,9 +404,10 @@ rt_opts <- function(prior = list(mean = 1, sd = 1),
#' @examples
#' # default settings
#' backcalc_opts()
backcalc_opts <- function(prior = "reports", prior_window = 14, rt_window = 1) {
backcalc_opts <- function(prior = c("reports", "none", "infections"),
prior_window = 14, rt_window = 1) {
backcalc <- list(
prior = arg_match(prior, values = c("reports", "none", "infections")),
prior = arg_match(prior),
prior_window = prior_window,
rt_window = as.integer(rt_window)
)
Expand Down Expand Up @@ -482,7 +480,7 @@ gp_opts <- function(basis_prop = 0.2,
ls_min = 0,
ls_max = 60,
alpha_sd = 0.05,
kernel = "matern_3/2",
kernel = c("matern_3/2", "se"),
matern_type = 3 / 2) {
gp <- list(
basis_prop = basis_prop,
Expand All @@ -492,7 +490,7 @@ gp_opts <- function(basis_prop = 0.2,
ls_min = ls_min,
ls_max = ls_max,
alpha_sd = alpha_sd,
kernel = arg_match(kernel, values = c("se", "matern_3/2")),
kernel = arg_match(kernel),
matern_type = matern_type
)

Expand Down Expand Up @@ -552,7 +550,7 @@ gp_opts <- function(basis_prop = 0.2,
#'
#' # Scale reported data
#' obs_opts(scale = list(mean = 0.2, sd = 0.02))
obs_opts <- function(family = "negbin",
obs_opts <- function(family = c("negbin", "poisson"),
phi = list(mean = 0, sd = 1),
weight = 1,
week_effect = TRUE,
Expand Down Expand Up @@ -581,7 +579,7 @@ obs_opts <- function(family = "negbin",
phi <- list(mean = phi[1], sd = phi[2])
}
obs <- list(
family = arg_match(family, values = c("poisson", "negbin")),
family = arg_match(family),
phi = phi,
weight = weight,
week_effect = week_effect,
Expand Down Expand Up @@ -689,10 +687,10 @@ stan_sampling_opts <- function(cores = getOption("mc.cores", 1L),
seed = as.integer(runif(1, 1, 1e8)),
future = FALSE,
max_execution_time = Inf,
backend = "rstan",
backend = c("rstan", "cmdstanr"),
...) {
dot_args <- list(...)
backend <- arg_match(backend, values = c("rstan", "cmdstanr"))
backend <- arg_match(backend)
opts <- list(
chains = chains,
save_warmup = save_warmup,
Expand Down Expand Up @@ -800,12 +798,12 @@ stan_vb_opts <- function(samples = 2000,
#' @seealso [rstan_sampling_opts()] [rstan_vb_opts()]
rstan_opts <- function(object = NULL,
samples = 2000,
method = "sampling", ...) {
method = c("sampling", "vb"), ...) {
lifecycle::deprecate_warn(
"1.5.0", "rstan_opts()",
"stan_opts()"
)
method <- arg_match(method, values = c("sampling", "vb"))
method <- arg_match(method)
# shared everywhere opts
if (is.null(object)) {
object <- stanmodels$estimate_infections
Expand Down Expand Up @@ -882,13 +880,13 @@ rstan_opts <- function(object = NULL,
#' stan_opts(method = "vb")
stan_opts <- function(object = NULL,
samples = 2000,
method = "sampling",
backend = "rstan",
method = c("sampling", "vb"),
backend = c("rstan", "cmdstanr"),
init_fit = NULL,
return_fit = TRUE,
...) {
method <- arg_match(method, values = c("sampling", "vb"))
backend <- arg_match(backend, values = c("rstan", "cmdstanr"))
method <- arg_match(method)
backend <- arg_match(backend)
if (backend == "cmdstanr" && !requireNamespace("cmdstanr", quietly = TRUE)) {
stop(
"The `cmdstanr` package needs to be installed for using the ",
Expand Down
30 changes: 15 additions & 15 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ plot_CrIs <- function(plot, CrIs, alpha, linewidth) {
#' )
plot_estimates <- function(estimate, reported, ylab = "Cases", hline,
obs_as_col = TRUE, max_plot = 10,
estimate_type = NULL) {
estimate_type = c(
"Estimate", "Estimate based on partial data",
"Forecast")
) {
# convert input to data.table
estimate <- data.table::as.data.table(estimate)
if (!missing(reported)) {
Expand All @@ -129,14 +132,7 @@ plot_estimates <- function(estimate, reported, ylab = "Cases", hline,
estimate <- estimate[, type := to_sentence(type)]

orig_estimate <- copy(estimate)
if (!is.null(estimate_type)) {
estimate_type <- arg_match(
estimate_type,
values = c("Estimate", "Estimate based on partial data", "Forecast"),
multiple = TRUE
)
estimate <- estimate[type %in% estimate_type]
}
estimate_type <- arg_match(estimate_type, multiple = TRUE)
# scale plot values based on reported cases
if (!missing(reported) && !is.na(max_plot)) {
sd_cols <- c(
Expand Down Expand Up @@ -370,9 +366,10 @@ plot_summary <- function(summary_results,
#'
#' @param x A list of output as produced by `estimate_infections`
#'
#' @param type A character vector indicating the name of plots to return.
#' @param type A character vector indicating the name of the plot to return.
#' Defaults to "summary" with supported options being "infections", "reports",
#' "R", "growth_rate", "summary", "all".
#' "R", "growth_rate", "summary", "all". If "all" is supplied all plots are
#' generated.
#'
#' @param ... Pass additional arguments to report_plots
#' @importFrom rlang arg_match
Expand All @@ -382,15 +379,18 @@ plot_summary <- function(summary_results,
#' @method plot estimate_infections
#' @return List of plots as produced by [report_plots()]
#' @export
plot.estimate_infections <- function(x, type = "summary", ...) {
plot.estimate_infections <- function(x,
type = c(
"summary", "infections", "reports", "R",
"growth_rate", "all"
), ...) {
out <- report_plots(
summarised_estimates = x$summarised,
reported = x$observations, ...
)
choices <- c("infections", "reports", "R", "growth_rate", "summary", "all")
type <- arg_match(type, values = choices, multiple = TRUE)
type <- arg_match(type)
if (type == "all") {
type <- choices[-length(choices)]
type <- c("summary", "infections", "reports", "R", "growth_rate")
}

if (!is.null(out)) {
Expand Down
14 changes: 6 additions & 8 deletions R/stan.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,15 @@
#' @importFrom rlang arg_match
#' @return A `cmdstanr` model.
#' @export
package_model <- function(model = "estimate_infections",
package_model <- function(model = c(
"estimate_infections", "simulate_infections",
"estimate_secondary", "simulate_secondary",
"estimate_truncation", "dist_fit"
),
include = system.file("stan", package = "EpiNow2"),
verbose = FALSE,
...) {
model <- arg_match(
model,
c(
"estimate_infections", "simulate_infections", "estimate_secondary",
"simulate_secondary", "estimate_truncation", "dist_fit"
)
)
model <- arg_match(model)
model_file <- system.file(
"stan", paste0(model, ".stan"),
package = "EpiNow2"
Expand Down
16 changes: 10 additions & 6 deletions R/summarise.R
Original file line number Diff line number Diff line change
Expand Up @@ -752,11 +752,13 @@ calc_summary_measures <- function(samples,
#' @method summary epinow
#' @return Returns a `<data.frame>` of summary output
#' @export
summary.epinow <- function(object, output = "estimates",
summary.epinow <- function(object,
output = c(
"estimates", "forecast", "estimated_reported_cases"
),
date = NULL, params = NULL,
...) {
choices <- c("estimates", "forecast", "estimated_reported_cases")
output <- arg_match(output, values = choices, multiple = FALSE)
output <- arg_match(output)
if (output == "estimates") {
out <- summary(object$estimates,
date = date,
Expand Down Expand Up @@ -800,10 +802,12 @@ summary.epinow <- function(object, output = "estimates",
#' @method summary estimate_infections
#' @return Returns a `<data.frame>` of summary output
#' @export
summary.estimate_infections <- function(object, type = "snapshot",
summary.estimate_infections <- function(object,
type = c(
"snapshot", "parameters", "samples"
),
date = NULL, params = NULL, ...) {
choices <- c("snapshot", "parameters", "samples")
type <- arg_match(type, values = choices, multiple = FALSE)
type <- arg_match(type)
if (is.null(date)) {
target_date <- unique(
object$summarised[type != "forecast"][date == max(date)]$date
Expand Down
6 changes: 5 additions & 1 deletion man/backcalc_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/check_reports_valid.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/convolve_and_scale.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/create_future_rt.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/gp_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/obs_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/package_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit a2b741a

Please sign in to comment.