Skip to content

Issue #386: Add direct model allowing pass through to brms #393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
S3method(add_mean_sd,default)
S3method(add_mean_sd,gamma_samples)
S3method(add_mean_sd,lognormal_samples)
S3method(as_direct_model,data.frame)
S3method(as_latent_individual,data.frame)
S3method(epidist,default)
S3method(epidist_family_model,default)
Expand All @@ -17,9 +18,11 @@ S3method(epidist_model_prior,default)
S3method(epidist_stancode,default)
S3method(epidist_stancode,epidist_latent_individual)
S3method(epidist_validate,default)
S3method(epidist_validate,epidist_direct_model)
S3method(epidist_validate,epidist_latent_individual)
export(add_event_vars)
export(add_mean_sd)
export(as_direct_model)
export(as_latent_individual)
export(epidist)
export(epidist_diagnostics)
Expand All @@ -35,6 +38,7 @@ export(epidist_stancode)
export(epidist_validate)
export(filter_obs_by_obs_time)
export(filter_obs_by_ptime)
export(is_direct_model)
export(is_latent_individual)
export(observe_process)
export(predict_delay_parameters)
Expand Down
68 changes: 68 additions & 0 deletions R/direct_model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#' Prepare direct model to pass through to `brms`
#'
#' @param data A `data.frame` containing line list data
#' @family direct_model
#' @export
as_direct_model <- function(data) {
UseMethod("as_direct_model")
}

assert_direct_model_input <- function(data) {
assert_data_frame(data)
assert_names(names(data), must.include = c("case", "ptime", "stime"))
assert_integer(data$case, lower = 0)
assert_numeric(data$ptime, lower = 0)
assert_numeric(data$stime, lower = 0)
}

#' Prepare latent individual model
#'
#' This function prepares data for use with the direct model. It does this by
#' adding columns used in the model to the `data` object provided. To do this,
#' the `data` must already have columns for the case number (integer),
#' (positive, numeric) times for the primary and secondary event times. The
#' output of this function is a `epidist_direct_model` class object, which may
#' be passed to [epidist()] to perform inference for the model.
#'
#' @param data A `data.frame` containing line list data
#' @rdname as_direct_model
#' @method as_direct_model data.frame
#' @family direct_model
#' @autoglobal
#' @export
as_direct_model.data.frame <- function(data) {
assert_direct_model_input(data)
class(data) <- c("epidist_direct_model", class(data))
data <- data |>
mutate(delay = .data$stime - .data$ptime)
epidist_validate(data)
return(data)
}

#' Validate direct model data
#'
#' This function checks whether the provided `data` object is suitable for
#' running the direct model. As well as making sure that
#' `is_direct_model()` is true, it also checks that `data` is a `data.frame`
#' with the correct columns.
#'
#' @param data A `data.frame` containing line list data
#' @param ... ...
#' @method epidist_validate epidist_direct_model
#' @family direct_model
#' @export
epidist_validate.epidist_direct_model <- function(data, ...) {
assert_true(is_direct_model(data))
assert_direct_model_input(data)
assert_names(names(data), must.include = c("case", "ptime", "stime", "delay"))
assert_numeric(data$delay, lower = 0)
}

#' Check if data has the `epidist_direct_model` class
#'
#' @param data A `data.frame` containing line list data
#' @family latent_individual
#' @export
is_direct_model <- function(data) {
inherits(data, "epidist_direct_model")
}
3 changes: 3 additions & 0 deletions R/formula.R
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,8 @@ epidist_formula_model <- function(data, formula, ...) {
#' @family formula
#' @export
epidist_formula_model.default <- function(data, formula, ...) {
formula <- stats::update(
formula, delay ~ .
)
return(formula)
}
4 changes: 4 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ reference:
desc: Specific methods for the latent individual model
contents:
- has_concept("latent_individual")
- title: Direct model
desc: Specific methods for the direct model
contents:
- has_concept("direct_model")
- title: Postprocess
desc: Functions for postprocessing model output
contents:
Expand Down
30 changes: 30 additions & 0 deletions man/as_direct_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/as_latent_individual.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/epidist_family_model.epidist_latent_individual.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/epidist_formula_model.epidist_latent_individual.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions man/epidist_validate.epidist_direct_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/epidist_validate.epidist_latent_individual.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions man/is_direct_model.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/is_latent_individual.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 50 additions & 0 deletions tests/testthat/test-direct_model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
test_that("as_direct_model.data.frame with default settings an object with the correct classes", { # nolint: line_length_linter.
prep_obs <- as_direct_model(sim_obs)
expect_s3_class(prep_obs, "data.frame")
expect_s3_class(prep_obs, "epidist_direct_model")
})

test_that("as_direct_model.data.frame errors when passed incorrect inputs", { # nolint: line_length_linter.
expect_error(as_direct_model(list()))
expect_error(as_direct_model(sim_obs[, 1]))
expect_error({
sim_obs$case <- paste("case_", seq_len(nrow(sim_obs)))
as_direct_model(sim_obs)
})
})

# Make this data available for other tests
prep_obs <- as_direct_model(sim_obs)
family_lognormal <- epidist_family(prep_obs, family = brms::lognormal())

test_that("is_direct_model returns TRUE for correct input", { # nolint: line_length_linter.
expect_true(is_direct_model(prep_obs))
expect_true({
x <- list()
class(x) <- "epidist_direct_model"
is_direct_model(x)
})
})

test_that("is_direct_model returns FALSE for incorrect input", { # nolint: line_length_linter.
expect_false(is_direct_model(list()))
expect_false({
x <- list()
class(x) <- "epidist_direct_model_extension"
is_direct_model(x)
})
})

test_that("epidist_validate.epidist_direct_model doesn't produce an error for correct input", { # nolint: line_length_linter.
expect_no_error(epidist_validate(prep_obs))
})

test_that("epidist_validate.epidist_direct_model returns FALSE for incorrect input", { # nolint: line_length_linter.
expect_error(epidist_validate(list()))
expect_error(epidist_validate(prep_obs[, 1]))
expect_error({
x <- list()
class(x) <- "epidist_direct_model"
epidist_validate(x)
})
})
36 changes: 36 additions & 0 deletions tests/testthat/test-int-direct_model.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Note: some tests in this script are stochastic. As such, test failure may be
# bad luck rather than indicate an issue with the code. However, as these tests
# are reproducible, the distribution of test failures may be investigated by
# varying the input seed. Test failure at an unusually high rate does suggest
# a potential code issue.

prep_obs <- as_direct_model(sim_obs)

test_that("epidist.epidist_direct_model Stan code has no syntax errors and compiles in the default case", { # nolint: line_length_linter.
skip_on_cran()
stancode <- epidist(
data = prep_obs,
fn = brms::make_stancode,
output_dir = fs::dir_create(tempfile())
)
mod <- cmdstanr::cmdstan_model(
stan_file = cmdstanr::write_stan_file(stancode), compile = FALSE
)
expect_true(mod$check_syntax())
expect_no_error(mod$compile())
})

test_that("epidist.epidist_direct_model fits and the MCMC converges in the default case", { # nolint: line_length_linter.
# Note: this test is stochastic. See note at the top of this script
skip_on_cran()
set.seed(1)
fit <- epidist(
data = prep_obs,
seed = 1,
silent = 2,
output_dir = fs::dir_create(tempfile())
)
expect_s3_class(fit, "brmsfit")
expect_s3_class(fit, "epidist_fit")
expect_convergence(fit)
})
Loading