Skip to content

Commit

Permalink
Add option to accumulate observations (#534)
Browse files Browse the repository at this point in the history
* add option to accumulate observations

* accumulate in estimate_secondary model

* add test for weekly accumulation

* check there's data to fit initial growth model

* ignore first observation when accumulating

* document "na" argument

* add news item

* update obs_opts tests

* make logical operator scalar

* make NA option work with estimate_secondary

* add tests

* Apply suggestions from code review

Co-authored-by: Sam Abbott <[email protected]>

---------

Co-authored-by: Sam Abbott <[email protected]>
  • Loading branch information
sbfnk and seabbs authored Feb 15, 2024
1 parent df1fdc8 commit 50fc3cf
Show file tree
Hide file tree
Showing 15 changed files with 222 additions and 49 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ importFrom(data.table,fwrite)
importFrom(data.table,getDTthreads)
importFrom(data.table,melt)
importFrom(data.table,merge.data.table)
importFrom(data.table,nafill)
importFrom(data.table,rbindlist)
importFrom(data.table,setDT)
importFrom(data.table,setDTthreads)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
## Model changes

* Updated the parameterisation of the dispersion term `phi` to be `phi = 1 / sqrt_phi ^ 2` rather than the previous parameterisation `phi = 1 / sqrt(sqrt_phi)` based on the suggested prior [here](https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations#story-when-the-generic-prior-fails-the-case-of-the-negative-binomial) and the performance benefits seen in the `epinowcast` package (see [here](https://github.com/epinowcast/epinowcast/blob/8eff560d1fd8305f5fb26c21324b2bfca1f002b4/inst/stan/epinowcast.stan#L314)). By @seabbs in # and reviewed by @sbfnk.
* Added an `na` argument to `obs_opts()` that allows the user to specify whether NA values in the data should be interpreted as missing or accumulated in the next non-NA data point. By @sbfnk in #534 and reviewed by @seabbs.

# EpiNow2 1.4.0

Expand Down
35 changes: 26 additions & 9 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#' @export
#' @examples
#' create_clean_reported_cases(example_confirmed, 7)
create_clean_reported_cases <- function(reported_cases, horizon,
create_clean_reported_cases <- function(reported_cases, horizon = 0,
filter_leading_zeros = TRUE,
zero_threshold = Inf,
fill = NA_integer_) {
Expand Down Expand Up @@ -75,6 +75,25 @@ create_clean_reported_cases <- function(reported_cases, horizon,
return(reported_cases)
}

#' Create complete cases
#' @description `r lifecycle::badge("stable")`
#' Creates a complete data set without NA values and appropriate indices
#'
#' @param cases; data frame with a column "confirm" that may contain NA values
#' @param burn_in; integer (default 0). Number of days to remove from the
#' start of the time series be filtered out.
#'
#' @return A data frame without NA values, with two columns: confirm (number)
#' @author Sebastian Funk
#' @importFrom data.table setDT
#' @keywords internal
create_complete_cases <- function(cases) {
cases <- setDT(cases)
cases[, lookup := seq_len(.N)]
cases <- cases[!is.na(cases$confirm)]
return(cases[])
}

#' Create Delay Shifted Cases
#'
#' @description `r lifecycle::badge("stable")`
Expand Down Expand Up @@ -397,6 +416,7 @@ create_obs_model <- function(obs = obs_opts(), dates) {
week_effect = ifelse(obs$week_effect, obs$week_length, 1),
obs_weight = obs$weight,
obs_scale = as.numeric(length(obs$scale) != 0),
accumulate = obs$accumulate,
likelihood = as.numeric(obs$likelihood),
return_likelihood = as.numeric(obs$return_likelihood)
)
Expand Down Expand Up @@ -447,16 +467,13 @@ create_stan_data <- function(reported_cases, seeding_time,
backcalc, shifted_cases) {

cases <- reported_cases[(seeding_time + 1):(.N - horizon)]
cases[, lookup := seq_len(.N)]
complete_cases <- cases[!is.na(cases$confirm)]
cases_time <- complete_cases$lookup
complete_cases <- complete_cases$confirm
complete_cases <- create_complete_cases(cases)
cases <- cases$confirm

data <- list(
cases = complete_cases,
cases_time = cases_time,
lt = length(cases_time),
cases = complete_cases$confirm,
cases_time = complete_cases$lookup,
lt = nrow(complete_cases),
shifted_cases = shifted_cases,
t = length(reported_cases$date),
horizon = horizon,
Expand All @@ -481,7 +498,7 @@ create_stan_data <- function(reported_cases, seeding_time,
is.na(data$prior_infections) || is.null(data$prior_infections),
0, data$prior_infections
)
if (data$seeding_time > 1) {
if (data$seeding_time > 1 && nrow(first_week) > 1) {
safe_lm <- purrr::safely(stats::lm)
data$prior_growth <- safe_lm(log(confirm) ~ t, data = first_week)[[1]]
data$prior_growth <- ifelse(is.null(data$prior_growth), 0,
Expand Down
17 changes: 14 additions & 3 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@
#' @inheritParams calc_CrIs
#' @importFrom rstan sampling
#' @importFrom lubridate wday
#' @importFrom data.table as.data.table merge.data.table
#' @importFrom data.table as.data.table merge.data.table nafill
#' @importFrom utils modifyList
#' @importFrom checkmate assert_class assert_numeric assert_data_frame
#' assert_logical
Expand Down Expand Up @@ -166,6 +166,15 @@ estimate_secondary <- function(reports,
assert_logical(verbose)

reports <- data.table::as.data.table(reports)
secondary_reports <- reports[, list(date, confirm = secondary)]
secondary_reports <- create_clean_reported_cases(secondary_reports)
## fill in missing data (required if fitting to prevalence)
complete_secondary <- create_complete_cases(secondary_reports)

## fill down
secondary_reports[, confirm := nafill(confirm, type = "locf")]
## fill any early data up
secondary_reports[, confirm := nafill(confirm, type = "nocb")]

if (burn_in >= nrow(reports)) {
stop("burn_in is greater or equal to the number of observations.
Expand All @@ -174,8 +183,10 @@ estimate_secondary <- function(reports,
# observation and control data
data <- list(
t = nrow(reports),
obs = reports$secondary,
primary = reports$primary,
obs = secondary_reports$confirm,
obs_time = complete_secondary[lookup > burn_in]$lookup - burn_in,
lt = sum(complete_secondary$lookup > burn_in),
burn_in = burn_in,
seeding_time = 0
)
Expand Down Expand Up @@ -391,7 +402,7 @@ plot.estimate_secondary <- function(x, primary = FALSE,
from = NULL, to = NULL,
new_obs = NULL,
...) {
predictions <- data.table::copy(x$predictions)
predictions <- data.table::copy(x$predictions)[!is.na(secondary)]

if (!is.null(new_obs)) {
new_obs <- data.table::as.data.table(new_obs)
Expand Down
53 changes: 35 additions & 18 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -427,32 +427,35 @@ gp_opts <- function(basis_prop = 0.2,
#' Defines a list specifying the structure of the observation
#' model. Custom settings can be supplied which override the defaults.
#' @param family Character string defining the observation model. Options are
#' Negative binomial ("negbin"), the default, and Poisson.
#' @param phi A numeric vector of length 2, defaults to 0, 1. Indicates the
#' mean and standard deviation of the normal prior used for the observation
#' process.
#'
#' @param weight Numeric, defaults to 1. Weight to give the observed data in
#' the log density.
#' Negative binomial ("negbin"), the default, and Poisson.
#' @param phi A numeric vector of length 2, defaults to 0, 1. Indicates the mean
#' and standard deviation of the normal prior used for the observation
#' process.
#' @param weight Numeric, defaults to 1. Weight to give the observed data in the
#' log density.
#' @param week_effect Logical defaulting to `TRUE`. Should a day of the week
#' effect be used in the observation model.
#'
#' effect be used in the observation model.
#' @param week_length Numeric assumed length of the week in days, defaulting to
#' 7 days. This can be modified if data aggregated over a period other than a
#' week or if data has a non-weekly periodicity.
#'
#' @param scale List, defaulting to an empty list. Should an scaling factor be
#' applied to map latent infections (convolved to date of report). If none
#' empty a mean (`mean`) and standard deviation (`sd`) needs to be supplied
#' defining the normally distributed scaling factor.
#'
#' applied to map latent infections (convolved to date of report). If none
#' empty a mean (`mean`) and standard deviation (`sd`) needs to be supplied
#' defining the normally distributed scaling factor.
#' @param na Character. Options are "missing" (the default) and "accumulate".
#' This determines how NA values in the data are interpreted. If set to
#' "missing", any NA values in the observation data set will be interpreted as
#' missing and skipped in the likelihood. If set to "accumulate", modelled
#' observations will be accumulated and added to the next non-NA data point.
#' This can be used to model incidence data that is reported at less than
#' daily intervals. If set to "accumulate", the first data point is not
#' included in the likelihood but used only to reset modelled observations to
#' zero.
#' @param likelihood Logical, defaults to `TRUE`. Should the likelihood be
#' included in the model.
#'
#' included in the model.
#' @param return_likelihood Logical, defaults to `FALSE`. Should the likelihood
#' be returned by the model.
#' be returned by the model.
#' @importFrom rlang arg_match
#'
#' @return An `<obs_opts>` object of observation model settings.
#' @author Sam Abbott
#' @export
Expand All @@ -471,18 +474,32 @@ obs_opts <- function(family = "negbin",
week_effect = TRUE,
week_length = 7,
scale = list(),
na = c("missing", "accumulate"),
likelihood = TRUE,
return_likelihood = FALSE) {
if (length(phi) != 2 || !is.numeric(phi)) {
stop("phi be numeric and of length two")
}
na <- arg_match(na)
if (na == "accumulate") {
message(
"Accumulating modelled values that correspond to NA values in the data ",
"by adding them to the next non-NA data point. This means that the ",
"first data point is not included in the likelihood but used only to ",
"reset modelled observations to zero. If the first data point should be ",
"included in the likelihood this can be achieved by adding a data point ",
"of arbitrary value before the first data point."
)
}

obs <- list(
family = arg_match(family, values = c("poisson", "negbin")),
phi = phi,
weight = weight,
week_effect = week_effect,
week_length = week_length,
scale = scale,
accumulate = as.integer(na == "accumulate"),
likelihood = likelihood,
return_likelihood = return_likelihood
)
Expand Down
1 change: 1 addition & 0 deletions inst/stan/data/observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,6 @@
real obs_weight; // weight given to observation in log density
int likelihood; // Should the likelihood be included in the model
int return_likelihood; // Should the likehood be returned by the model
int accumulate; // Should missing values be accumulated
int<lower = 0> trunc_id; // id of truncation
int<lower = 0> delay_id; // id of delay
4 changes: 2 additions & 2 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ model {
// observed reports from mean of reports (update likelihood)
if (likelihood) {
report_lp(
cases, obs_reports[cases_time], rep_phi, phi_mean, phi_sd, model_type,
obs_weight
cases, cases_time, obs_reports, rep_phi, phi_mean, phi_sd, model_type,
obs_weight, accumulate
);
}
}
Expand Down
8 changes: 6 additions & 2 deletions inst/stan/estimate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ functions {

data {
int t; // time of observations
int lt; // time of observations
array[t] int<lower = 0> obs; // observed secondary data
array[lt] int obs_time; // observed secondary data
vector[t] primary; // observed primary data
int burn_in; // time period to not use for fitting
#include data/secondary.stan
Expand Down Expand Up @@ -83,8 +85,10 @@ model {
}
// observed secondary reports from mean of secondary reports (update likelihood)
if (likelihood) {
report_lp(obs[(burn_in + 1):t], secondary[(burn_in + 1):t],
rep_phi, phi_mean, phi_sd, model_type, 1);
report_lp(
obs[(burn_in + 1):t][obs_time], obs_time, secondary[(burn_in + 1):t],
rep_phi, phi_mean, phi_sd, model_type, 1, accumulate
);
}
}

Expand Down
40 changes: 32 additions & 8 deletions inst/stan/functions/observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,46 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd,
}
}
// update log density for reported cases
void report_lp(array[] int cases, vector reports,
void report_lp(array[] int cases, array[] int cases_time, vector reports,
array[] real rep_phi, real phi_mean, real phi_sd,
int model_type, real weight) {
int model_type, real weight, int accumulate) {
int n = num_elements(cases_time) - accumulate; // number of observations
vector[n] obs_reports; // reports at observation time
array[n] int obs_cases; // observed cases at observation time
if (accumulate) {
int t = num_elements(reports);
int i = 0;
int current_obs = 0;
obs_reports = rep_vector(0, n);
while (i <= t && current_obs <= n) {
if (current_obs > 0) { // first observation gets ignored when accumulating
obs_reports[current_obs] += reports[i];
}
if (i == cases_time[current_obs + 1]) {
current_obs += 1;
}
i += 1;
}
obs_cases = cases[2:(n + 1)];
} else {
obs_reports = reports[cases_time];
obs_cases = cases;
}
if (model_type) {
real dispersion = 1 / pow(rep_phi[model_type], 2);
real dispersion = 1 / pow(rep_phi[model_type], 2);
rep_phi[model_type] ~ normal(phi_mean, phi_sd) T[0,];
if (weight == 1) {
cases ~ neg_binomial_2(reports, dispersion);
obs_cases ~ neg_binomial_2(obs_reports, dispersion);
} else {
target += neg_binomial_2_lpmf(cases | reports, dispersion) * weight;
target += neg_binomial_2_lpmf(
obs_cases | obs_reports, dispersion
) * weight;
}
} else {
if (weight == 1) {
cases ~ poisson(reports);
obs_cases ~ poisson(obs_reports);
} else {
target += poisson_lpmf(cases | reports) * weight;
target += poisson_lpmf(obs_cases | obs_reports) * weight;
}
}
}
Expand Down Expand Up @@ -97,7 +121,7 @@ array[] int report_rng(vector reports, array[] real rep_phi, int model_type) {
if (model_type) {
dispersion = 1 / pow(rep_phi[model_type], 2);
}

for (s in 1:t) {
if (reports[s] < 1e-8) {
sampled_reports[s] = 0;
Expand Down
2 changes: 1 addition & 1 deletion man/create_clean_reported_cases.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/create_complete_cases.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 50fc3cf

Please sign in to comment.