diff --git a/NEWS.md b/NEWS.md index 4dd574e37..bf51d48dd 100644 --- a/NEWS.md +++ b/NEWS.md @@ -12,6 +12,7 @@ - The interface for defining delay distributions has been generalised to also cater for continuous distributions - When defining probability distributions these can now be truncated using the `tolerance` argument - Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in # and reviewed by @. +- Switch to broadcasting from random walks and added unit tests. By @seabbs in #747 and reviewed by @jamesmbaazam. - Optimised convolution code to take into account the relative length of the vectors being convolved. See #745 by @seabbs and reviewed by @jamesmbaazam. - Switch to broadcasting the day of the week effect. By @seabbs in #746 and reviewed by @jamesmbaazam. - A warning is now thrown if nonparametric PMFs passed to delay options have consecutive tail values that are below a certain low threshold as these lead to loss in speed with little gain in accuracy. By @jamesmbaazam in #752 and reviewed by @seabbs. @@ -28,6 +29,7 @@ - Updated the documentation of the dots argument of the `stan_sampling_opts()` to add that the dots are passed to `cmdstanr::sample()`. By @jamesmbaazam in #699 and reviewed by @sbfnk. - `generation_time_opts()` has been shortened to `gt_opts()` to make it easier to specify. Calls to both functions are equivalent. By @jamesmbaazam in #698 and reviewed by @seabbs and @sbfnk . +- Added stan documentation for `update_rt()`. By @seabbs in #747 and reviewed by @jamesmbaazam. # EpiNow2 1.5.2 diff --git a/R/create.R b/R/create.R index 89c61a67c..d3e8bedc6 100644 --- a/R/create.R +++ b/R/create.R @@ -261,15 +261,20 @@ create_future_rt <- function(future = c("latest", "project", "estimate"), #' #' # using breakpoints #' create_rt_data(rt_opts(use_breakpoints = TRUE), breakpoints = rep(1, 10)) +#' +#' # using random walk +#' create_rt_data(rt_opts(rw = 7), breakpoints = rep(1, 10)) #' } create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, delay = 0, horizon = 0) { + # Define if GP is on or off if (is.null(rt)) { rt <- rt_opts( use_rt = FALSE, future = "project", - gp_on = "R0" + gp_on = "R0", + rw = 0 ) } # define future Rt arguments @@ -279,24 +284,34 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, ) # apply random walk if (rt$rw != 0) { - breakpoints <- as.integer(seq_along(breakpoints) %% rt$rw == 0) + if (is.null(breakpoints)) { + stop("breakpoints must be supplied when using random walk") + } + + breakpoints <- seq_along(breakpoints) + breakpoints <- floor(breakpoints / rt$rw) if (!(rt$future == "project")) { max_bps <- length(breakpoints) - horizon + future_rt$from if (max_bps < length(breakpoints)) { - breakpoints[(max_bps + 1):length(breakpoints)] <- 0 + breakpoints[(max_bps + 1):length(breakpoints)] <- breakpoints[max_bps] } } + }else { + breakpoints <- cumsum(breakpoints) } - # check breakpoints - if (is.null(breakpoints) || sum(breakpoints) == 0) { + + if (sum(breakpoints) == 0) { rt$use_breakpoints <- FALSE } + # add a shift for 0 effect in breakpoints + breakpoints <- breakpoints + 1 + # map settings to underlying gp stan requirements rt_data <- list( r_mean = rt$prior$mean, r_sd = rt$prior$sd, estimate_r = as.numeric(rt$use_rt), - bp_n = ifelse(rt$use_breakpoints, sum(breakpoints, na.rm = TRUE), 0), + bp_n = ifelse(rt$use_breakpoints, max(breakpoints) - 1, 0), breakpoints = breakpoints, future_fixed = as.numeric(future_rt$fixed), fixed_from = future_rt$from, diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 1e97203e4..985cfa9c2 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -51,7 +51,7 @@ parameters{ array[estimate_r] real initial_infections ; // seed infections array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate array[bp_n > 0 ? 1 : 0] real bp_sd; // standard deviation of breakpoint effect - array[bp_n] real bp_effects; // Rt breakpoint effects + vector[bp_n] bp_effects; // Rt breakpoint effects // observation model vector[delay_params_length] delay_params; // delay parameters diff --git a/inst/stan/functions/rt.stan b/inst/stan/functions/rt.stan index 418c07675..ad2d877b1 100644 --- a/inst/stan/functions/rt.stan +++ b/inst/stan/functions/rt.stan @@ -1,26 +1,34 @@ -// update a vector of Rts +/** + * Update a vector of effective reproduction numbers (Rt) based on + * an intercept, breakpoints (i.e. a random walk), and a Gaussian + * process. + * + * @param t Length of the time series + * @param log_R Logarithm of the base reproduction number + * @param noise Vector of Gaussian process noise values + * @param bps Array of breakpoint indices + * @param bp_effects Vector of breakpoint effects + * @param stationary Flag indicating whether the Gaussian process is stationary + * (1) or non-stationary (0) + * @return A vector of length t containing the updated Rt values + */ vector update_Rt(int t, real log_R, vector noise, array[] int bps, - array[] real bp_effects, int stationary) { + vector bp_effects, int stationary) { // define control parameters int bp_n = num_elements(bp_effects); - int bp_c = 0; int gp_n = num_elements(noise); - // define result vectors - vector[t] bp = rep_vector(0, t); - vector[t] gp = rep_vector(0, t); - vector[t] R; - // initialise breakpoints + // initialise intercept + vector[t] R = rep_vector(log_R, t); + //initialise breakpoints + rw if (bp_n) { - for (s in 1:t) { - if (bps[s]) { - bp_c += bps[s]; - bp[s] = bp_effects[bp_c]; - } - } - bp = cumulative_sum(bp); + vector[bp_n + 1] bp0; + bp0[1] = 0; + bp0[2:(bp_n + 1)] = cumulative_sum(bp_effects); + R = R + bp0[bps]; } //initialise gaussian process if (gp_n) { + vector[t] gp = rep_vector(0, t); if (stationary) { gp[1:gp_n] = noise; // fix future gp based on last estimated @@ -31,18 +39,31 @@ vector update_Rt(int t, real log_R, vector noise, array[] int bps, gp[2:(gp_n + 1)] = noise; gp = cumulative_sum(gp); } + R = R + gp; } - // Calculate Rt - R = rep_vector(log_R, t) + bp + gp; - R = exp(R); - return(R); + + return exp(R); } -// Rt priors + +/** + * Calculate the log-probability of the reproduction number (Rt) priors + * + * @param log_R Logarithm of the base reproduction number + * @param initial_infections Array of initial infection values + * @param initial_growth Array of initial growth rates + * @param bp_effects Vector of breakpoint effects + * @param bp_sd Array of breakpoint standard deviations + * @param bp_n Number of breakpoints + * @param seeding_time Time point at which seeding occurs + * @param r_logmean Log-mean of the prior distribution for the base reproduction number + * @param r_logsd Log-standard deviation of the prior distribution for the base reproduction number + * @param prior_infections Prior mean for initial infections + * @param prior_growth Prior mean for initial growth rates + */ void rt_lp(vector log_R, array[] real initial_infections, array[] real initial_growth, - array[] real bp_effects, array[] real bp_sd, int bp_n, int seeding_time, + vector bp_effects, array[] real bp_sd, int bp_n, int seeding_time, real r_logmean, real r_logsd, real prior_infections, real prior_growth) { - // prior on R log_R ~ normal(r_logmean, r_logsd); //breakpoint effects on Rt if (bp_n > 0) { @@ -51,6 +72,7 @@ void rt_lp(vector log_R, array[] real initial_infections, array[] real initial_g } // initial infections initial_infections ~ normal(prior_infections, 0.2); + if (seeding_time > 1) { initial_growth ~ normal(prior_growth, 0.2); } diff --git a/man/EpiNow2-package.Rd b/man/EpiNow2-package.Rd index 3b1d4e286..f8f7529f3 100644 --- a/man/EpiNow2-package.Rd +++ b/man/EpiNow2-package.Rd @@ -43,7 +43,7 @@ Other contributors: \item Paul Mee \email{paul.mee@lshtm.ac.uk} [contributor] \item Peter Ellis \email{peter.ellis2013nz@gmail.com} [contributor] \item Pietro Monticone \email{pietro.monticone@edu.unito.it} [contributor] - \item Lloyd Chapman \email{lloyd.chapman1@lshtm.ac.uk} [contributor] + \item Lloyd Chapman \email{lloyd.chapman1@lshtm.ac.uk } [contributor] \item Andrew Johnson \email{andrew.johnson@arjohnsonau.com} [contributor] } diff --git a/man/create_rt_data.Rd b/man/create_rt_data.Rd index 4a02c1664..3c5233795 100644 --- a/man/create_rt_data.Rd +++ b/man/create_rt_data.Rd @@ -36,6 +36,9 @@ create_rt_data(rt = NULL) # using breakpoints create_rt_data(rt_opts(use_breakpoints = TRUE), breakpoints = rep(1, 10)) + +# using random walk +create_rt_data(rt_opts(rw = 7), breakpoints = rep(1, 10)) } } \seealso{ diff --git a/tests/testthat/test-create_rt_date.R b/tests/testthat/test-create_rt_date.R new file mode 100644 index 000000000..748ae80d4 --- /dev/null +++ b/tests/testthat/test-create_rt_date.R @@ -0,0 +1,88 @@ +test_that("create_rt_data returns expected default values", { + result <- create_rt_data() + + expect_type(result, "list") + expect_equal(result$r_mean, 1) + expect_equal(result$r_sd, 1) + expect_equal(result$estimate_r, 1) + expect_equal(result$bp_n, 0) + expect_equal(result$breakpoints, numeric(0)) + expect_equal(result$future_fixed, 1) + expect_equal(result$fixed_from, 0) + expect_equal(result$pop, 0) + expect_equal(result$stationary, 0) + expect_equal(result$future_time, 0) +}) + +test_that("create_rt_data handles NULL rt input correctly", { + result <- create_rt_data(rt = NULL) + + expect_equal(result$estimate_r, 0) + expect_equal(result$future_fixed, 0) + expect_equal(result$stationary, 1) +}) + +test_that("create_rt_data handles custom rt_opts correctly", { + custom_rt <- rt_opts( + prior = list(mean = 2, sd = 0.5), + use_rt = FALSE, + rw = 0, + use_breakpoints = FALSE, + future = "project", + gp_on = "R0", + pop = 1000000 + ) + + result <- create_rt_data(rt = custom_rt, horizon = 7) + + expect_equal(result$r_mean, 2) + expect_equal(result$r_sd, 0.5) + expect_equal(result$estimate_r, 0) + expect_equal(result$pop, 1000000) + expect_equal(result$stationary, 1) + expect_equal(result$future_time, 7) +}) + +test_that("create_rt_data handles breakpoints correctly", { + result <- create_rt_data(rt_opts(use_breakpoints = TRUE), + breakpoints = c(1, 0, 1, 0, 1)) + + expect_equal(result$bp_n, 3) + expect_equal(result$breakpoints, c(2, 2, 3, 3, 4)) +}) + +test_that("create_rt_data handles random walk correctly", { + result <- create_rt_data(rt_opts(rw = 2), + breakpoints = rep(1, 10)) + + expect_equal(result$bp_n, 5) + expect_equal(result$breakpoints, c(1, 2, 2, 3, 3, 4, 4, 5, 5, 6)) +}) + +test_that("create_rt_data throws error for invalid inputs", { + expect_error(create_rt_data(rt_opts(rw = 2)), + "breakpoints must be supplied when using random walk") +}) + +test_that("create_rt_data handles future projections correctly", { + result <- create_rt_data(rt_opts(future = "project"), horizon = 7) + + expect_equal(result$future_fixed, 0) + expect_equal(result$fixed_from, 0) + expect_equal(result$future_time, 7) +}) + +test_that("create_rt_data handles zero sum breakpoints", { + result <- create_rt_data(rt_opts(use_breakpoints = TRUE), + breakpoints = rep(0, 5)) + + expect_equal(result$bp_n, 0) +}) + +test_that("create_rt_data adjusts breakpoints for horizon", { + result <- create_rt_data(rt_opts(rw = 2, future = "latest"), + breakpoints = rep(1, 10), + horizon = 3) + + expect_equal(result$breakpoints, c(1, 2, 2, 3, 3, 4, 4, 4, 4, 4)) +}) diff --git a/tests/testthat/test-rt_opts.R b/tests/testthat/test-rt_opts.R new file mode 100644 index 000000000..2d415ff8e --- /dev/null +++ b/tests/testthat/test-rt_opts.R @@ -0,0 +1,61 @@ +test_that("rt_opts returns expected default values", { + result <- rt_opts() + + expect_s3_class(result, "rt_opts") + expect_equal(result$prior, list(mean = 1, sd = 1)) + expect_true(result$use_rt) + expect_equal(result$rw, 0) + expect_true(result$use_breakpoints) + expect_equal(result$future, "latest") + expect_equal(result$pop, 0) + expect_equal(result$gp_on, "R_t-1") +}) + +test_that("rt_opts handles custom inputs correctly", { + result <- rt_opts( + prior = list(mean = 2, sd = 0.5), + use_rt = FALSE, + rw = 7, + use_breakpoints = FALSE, + future = "project", + gp_on = "R0", + pop = 1000000 + ) + + expect_equal(result$prior, list(mean = 2, sd = 0.5)) + expect_false(result$use_rt) + expect_equal(result$rw, 7) + expect_true(result$use_breakpoints) # Should be TRUE when rw > 0 + expect_equal(result$future, "project") + expect_equal(result$pop, 1000000) + expect_equal(result$gp_on, "R0") +}) + +test_that("rt_opts sets use_breakpoints to TRUE when rw > 0", { + result <- rt_opts(rw = 3, use_breakpoints = FALSE) + expect_true(result$use_breakpoints) +}) + +test_that("rt_opts throws error for invalid prior", { + expect_error(rt_opts(prior = list(mean = 1)), + "prior must have both a mean and sd specified") + expect_error(rt_opts(prior = list(sd = 1)), + "prior must have both a mean and sd specified") +}) + +test_that("rt_opts validates gp_on argument", { + expect_error(rt_opts(gp_on = "invalid"), "must be one") +}) + +test_that("rt_opts returns object of correct class", { + result <- rt_opts() + expect_s3_class(result, "rt_opts") + expect_true("list" %in% class(result)) +}) + +test_that("rt_opts handles edge cases correctly", { + result <- rt_opts(rw = 0.1, pop = -1) + expect_equal(result$rw, 0.1) + expect_equal(result$pop, -1) + expect_true(result$use_breakpoints) +}) diff --git a/tests/testthat/test-stan-rt.R b/tests/testthat/test-stan-rt.R index cddfa1ef9..1b4c40153 100644 --- a/tests/testthat/test-stan-rt.R +++ b/tests/testthat/test-stan-rt.R @@ -32,29 +32,29 @@ test_that("update_Rt works when Rt is fixed", { }) test_that("update_Rt works when Rt is fixed but a breakpoint is present", { expect_equal( - round(update_Rt(5, log(1.2), numeric(0), c(0, 0, 1, 0, 0), 0.1, 0), 2), + round(update_Rt(5, log(1.2), numeric(0), c(1, 1, 2, 2, 2), 0.1, 0), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), numeric(0), c(0, 0, 1, 0, 0), 0.1, 1), 2), + round(update_Rt(5, log(1.2), numeric(0), c(1, 1, 2, 2, 2), 0.1, 1), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), numeric(0), c(0, 1, 1, 0, 0), rep(0.1, 2), 0), 2), + round(update_Rt(5, log(1.2), numeric(0), c(1, 2, 3, 3, 3), rep(0.1, 2), 0), 2), c(1.2, 1.33, rep(1.47, 3)) ) }) test_that("update_Rt works when Rt is variable and a breakpoint is present", { expect_equal( - round(update_Rt(5, log(1.2), rep(0, 4), c(0, 0, 1, 0, 0), 0.1, 0), 2), + round(update_Rt(5, log(1.2), rep(0, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), rep(0, 5), c(0, 0, 1, 0, 0), 0.1, 1), 2), + round(update_Rt(5, log(1.2), rep(0, 5), c(1, 1, 2, 2, 2), 0.1, 1), 2), c(1.2, 1.2, rep(1.33, 3)) ) expect_equal( - round(update_Rt(5, log(1.2), rep(0.1, 4), c(0, 0, 1, 0, 0), 0.1, 0), 2), + round(update_Rt(5, log(1.2), rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2), c(1.20, 1.33, 1.62, 1.79, 1.98) ) })