Skip to content

Commit

Permalink
CRAN check issues and fix changes from @sbfnk #303 and #304
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Aug 31, 2022
1 parent 5e2aae2 commit 99dec87
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 33 deletions.
20 changes: 12 additions & 8 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -389,23 +389,27 @@ create_stan_data <- function(reported_cases, generation_time,
truncation) {

## make sure we have at least max_gt seeding time
delays$seeding_time <- max(delays$seeding_time, generation_time$max)
delays$seeding_time <- max(delays$seeding_time, generation_time$gt_max)

## for backwards compatibility call generation_time_opts internally
generation_time <- do.call(generation_time_opts, generation_time)

if (is.null(generation_time[["gt_mean"]])) {
generation_time <- do.call(generation_time_opts, generation_time)
}

cases <- reported_cases[(delays$seeding_time + 1):(.N - horizon)]$confirm

data <- list(
cases = cases,
shifted_cases = shifted_cases,
t = length(reported_cases$date),
horizon = horizon,
gt_mean_mean = generation_time$mean,
gt_mean_sd = generation_time$mean_sd,
gt_sd_mean = generation_time$sd,
gt_sd_sd = generation_time$sd_sd,
max_gt = generation_time$max,
gt_mean_mean = generation_time$gt_mean,
gt_mean_sd = generation_time$gt_mean_sd,
gt_sd_mean = generation_time$gt_sd,
gt_sd_sd = generation_time$gt_sd_sd,
max_gt = generation_time$gt_max,
gt_fixed = generation_time$gt_fixed,
gt_pmf = generation_time$gt_pmf,
burn_in = 0
)
# add delay data
Expand Down
40 changes: 29 additions & 11 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,14 @@ generation_time_opts <- function(mean = 1, mean_sd = 0, sd = 0, sd_sd = 0,
gt <- get_generation_time(
disease = disease, source = source, max_value = max
)
names(gt) <- paste0("gt_", names(gt))
gt$gt_fixed <- fixed
}

if (fixed) {
gt$gt_mean_sd <- 0
gt$gt_sd_sd <- 0
}

## check if generation time is fixed
if (gt$gt_sd == 0 && gt$gt_sd_sd == 0) {
Expand Down Expand Up @@ -138,15 +143,20 @@ delay_opts <- function(..., fixed = FALSE) {
}

if (length(data$delay_mean_mean) > 0) {
pmf <- c()
pmf <- list()
for (i in seq_along(data$delay_mean_mean)) {
pmf <- c(pmf, discretised_lognormal_pmf(
data$delay_mean_mean[i], data$delay_sd_mean[i], data$max_delay,
pmf[[i]] <- discretised_lognormal_pmf(
data$delay_mean_mean[i], data$delay_sd_mean[i], data$max_delay[i],
reverse = TRUE
))
)
data$delay_pmf <- pmf
}
}

if (fixed) {
data$delay_mean_sd <- rep(0, length(data$delay_sd_mean))
data$delay_sd_sd <- rep(0, length(data$delay_sd_sd))
}
return(data)
}

Expand Down Expand Up @@ -180,13 +190,21 @@ delay_opts <- function(..., fixed = FALSE) {
#' trunc_opts(mean = 3, sd = 2)
trunc_opts <- function(mean = 0 , sd = 0, mean_sd = 0, sd_sd = 0, max = 0) {
present <- !(mean == 0 & sd == 0 & max == 0)
data <- list()
data$truncation <- as.numeric(present)
data$trunc_mean_mean <- ifelse(present, mean, numeric())
data$trunc_mean_sd <- ifelse(present, mean_sd, numeric())
data$trunc_sd_mean <- ifelse(present, sd, numeric())
data$trunc_sd_sd <- ifelse(present, sd_sd, numeric())
data$max_truncation <- ifelse(present, max, numeric())
data <- list(
truncation = as.numeric(present),
trunc_mean_mean = numeric(0),
trunc_mean_sd = numeric(0),
trunc_sd_mean = numeric(0),
trunc_sd_sd = numeric(0),
max_truncation = numeric(0)
)
if (present) {
data$trunc_mean_mean <- mean
data$trunc_mean_sd <- mean_sd
data$trunc_sd_mean <- sd
data$trunc_sd_sd <- sd_sd
data$max_truncation <- max
}
return(data)
}

Expand Down
12 changes: 9 additions & 3 deletions README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,12 @@ reporting_delay <- list(
Here we define the incubation period and generation time based on literature estimates for Covid-19 (see [here](https://github.com/epiforecasts/EpiNow2/tree/master/data-raw) for the code that generates these estimates). Note that these distributions may not be applicable for your use case.

```{r}
generation_time <- get_generation_time(disease = "SARS-CoV-2", source = "ganyani")
incubation_period <- get_incubation_period(disease = "SARS-CoV-2", source = "lauer")
generation_time <- generation_time_opts(
disease = "SARS-CoV-2", source = "ganyani", fixed = TRUE
)
incubation_period <- get_incubation_period(
disease = "SARS-CoV-2", source = "lauer"
)
```

### [epinow()](https://epiforecasts.io/EpiNow2/reference/epinow.html)
Expand All @@ -120,7 +124,9 @@ Estimate cases by date of infection, the time-varying reproduction number, the r
```{r, message = FALSE, warning = FALSE}
estimates <- epinow(reported_cases = reported_cases,
generation_time = generation_time,
delays = delay_opts(incubation_period, reporting_delay),
delays = delay_opts(
incubation_period, reporting_delay, fixed = FALSE
),
rt = rt_opts(prior = list(mean = 2, sd = 0.2)),
stan = stan_opts(cores = 4))
names(estimates)
Expand Down
22 changes: 11 additions & 11 deletions tests/testthat/test-stan-rt.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,57 +3,57 @@ skip_on_cran()
# Test update_Rt
test_that("update_Rt works to produce multiple Rt estimates with a static gaussian process", {
expect_equal(
update_Rt(rep(1, 10), log(1.2), rep(0, 9), rep(10, 0), numeric(0), 0),
update_Rt(10, log(1.2), rep(0, 9), rep(10, 0), numeric(0), 0),
rep(1.2, 10)
)
})
test_that("update_Rt works to produce multiple Rt estimates with a non-static gaussian process", {
expect_equal(
round(update_Rt(rep(1, 10), log(1.2), rep(0.1, 9), rep(10, 0), numeric(0), 0), 2),
round(update_Rt(10, log(1.2), rep(0.1, 9), rep(10, 0), numeric(0), 0), 2),
c(1.20, 1.33, 1.47, 1.62, 1.79, 1.98, 2.19, 2.42, 2.67, 2.95)
)
})
test_that("update_Rt works to produce multiple Rt estimates with a non-static stationary gaussian process", {
expect_equal(
round(update_Rt(rep(1, 10), log(1.2), rep(0.1, 10), rep(10, 0), numeric(0), 1), 3),
round(update_Rt(10, log(1.2), rep(0.1, 10), rep(10, 0), numeric(0), 1), 3),
c(1.326, 1.326, 1.326, 1.326, 1.326, 1.326, 1.326, 1.326, 1.326, 1.326)
)
})
test_that("update_Rt works when Rt is fixed", {
expect_equal(
round(update_Rt(rep(1, 10), log(1.2), numeric(0), rep(10, 0), numeric(0), 0), 2),
round(update_Rt(10, log(1.2), numeric(0), rep(10, 0), numeric(0), 0), 2),
rep(1.2, 10)
)
expect_equal(
round(update_Rt(rep(1, 10), log(1.2), numeric(0), rep(10, 0), numeric(0), 1), 2),
round(update_Rt(10, log(1.2), numeric(0), rep(10, 0), numeric(0), 1), 2),
rep(1.2, 10)
)
})
test_that("update_Rt works when Rt is fixed but a breakpoint is present", {
expect_equal(
round(update_Rt(rep(1, 5), log(1.2), numeric(0), c(0, 0, 1, 0, 0), 0.1, 0), 2),
round(update_Rt(5, log(1.2), numeric(0), c(0, 0, 1, 0, 0), 0.1, 0), 2),
c(1.2, 1.2, rep(1.33, 3))
)
expect_equal(
round(update_Rt(rep(1, 5), log(1.2), numeric(0), c(0, 0, 1, 0, 0), 0.1, 1), 2),
round(update_Rt(5, log(1.2), numeric(0), c(0, 0, 1, 0, 0), 0.1, 1), 2),
c(1.2, 1.2, rep(1.33, 3))
)
expect_equal(
round(update_Rt(rep(1, 5), log(1.2), numeric(0), c(0, 1, 1, 0, 0), rep(0.1, 2), 0), 2),
round(update_Rt(5, log(1.2), numeric(0), c(0, 1, 1, 0, 0), rep(0.1, 2), 0), 2),
c(1.2, 1.33, rep(1.47, 3))
)
})
test_that("update_Rt works when Rt is variable and a breakpoint is present", {
expect_equal(
round(update_Rt(rep(1, 5), log(1.2), rep(0, 4), c(0, 0, 1, 0, 0), 0.1, 0), 2),
round(update_Rt(5, log(1.2), rep(0, 4), c(0, 0, 1, 0, 0), 0.1, 0), 2),
c(1.2, 1.2, rep(1.33, 3))
)
expect_equal(
round(update_Rt(rep(1, 5), log(1.2), rep(0, 5), c(0, 0, 1, 0, 0), 0.1, 1), 2),
round(update_Rt(5, log(1.2), rep(0, 5), c(0, 0, 1, 0, 0), 0.1, 1), 2),
c(1.2, 1.2, rep(1.33, 3))
)
expect_equal(
round(update_Rt(rep(1, 5), log(1.2), rep(0.1, 4), c(0, 0, 1, 0, 0), 0.1, 0), 2),
round(update_Rt(5, log(1.2), rep(0.1, 4), c(0, 0, 1, 0, 0), 0.1, 0), 2),
c(1.20, 1.33, 1.62, 1.79, 1.98)
)
})

0 comments on commit 99dec87

Please sign in to comment.