Skip to content

Commit

Permalink
Issue #306: Simplify prior infrastructure by allowing warning when ov…
Browse files Browse the repository at this point in the history
…erwriting (#379)

* Add abort/warn option to .replace_prior then use it in epidist_prior, and delete NULL case in code

* Fix bug (forgot to update name)

* Update tests

* Use code linking in epidist()

* Add argument documentation

* Redocument

* Add testing that the parameter set to a constant is indeed a constant

* Switch to "do nothing, or warn" prior set-up

* Better warning message and bug fix

* Expect a warning not an error

* Alter args of .replace_prior to make more external sense

* Keep old_prior as first argument
  • Loading branch information
athowes authored Oct 15, 2024
1 parent 9cc8089 commit 92c5104
Show file tree
Hide file tree
Showing 16 changed files with 74 additions and 68 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_true)
importFrom(cli,cli_abort)
importFrom(cli,cli_inform)
importFrom(cli,cli_warn)
importFrom(dplyr,filter)
importFrom(dplyr,mutate)
importFrom(dplyr,select)
2 changes: 1 addition & 1 deletion R/epidist-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@
#' @importFrom brms bf prior
#' @importFrom checkmate assert_data_frame assert_names assert_integer
#' assert_true assert_factor assert_numeric
#' @importFrom cli cli_abort cli_inform cli_abort
#' @importFrom cli cli_abort cli_inform cli_abort cli_warn
## usethis namespace: end
NULL
10 changes: 5 additions & 5 deletions R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ epidist_stancode <- function(data, ...) {
#' Fit epidemiological delay distributions using a `brms` interface
#'
#' @inheritParams epidist_validate
#' @param formula A formula object created using `brms::bf`. A formula must be
#' @param formula A formula object created using [brms::bf()]. A formula must be
#' provided for the distributional parameter `mu` common to all `brms` families.
#' Optionally, formulas may also be provided for additional distributional
#' parameters.
Expand All @@ -46,10 +46,10 @@ epidist_stancode <- function(data, ...) {
#' @param backend Character string naming the package to use as the backend for
#' fitting the Stan model. Options are `"rstan"` and `"cmdstanr"` (the default).
#' This option is passed directly through to `fn`.
#' @param fn The internal function to be called. By default this is `brms::brm`,
#' which performs inference for the specified model. Other options
#' `brms::make_stancode`, which returns the Stan code for the specified model,
#' and `brms::make_standata` which returns the data passed to Stan. These
#' @param fn The internal function to be called. By default this is
#' [brms::brm()] which performs inference for the specified model. Other options
#' [brms::make_stancode()], which returns the Stan code for the specified model,
#' and [brms::make_standata()] which returns the data passed to Stan. These
#' options may be useful for model debugging and extensions.
#' @param ... Additional arguments for method.
#' @family generics
Expand Down
26 changes: 11 additions & 15 deletions R/prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
#' 2. Family specific prior distributions from [epidist_family_prior()]
#' 3. User provided prior distributions
#' Each element of this list overwrites previous elements, such that user
#' provided prior distribution have the highest priority.
#' provided prior distribution have the highest priority. At the third stage,
#' if a prior distribution is provided which is not included in the model, then
#' a warning will be shown. To prevent this warning, do not pass prior
#' distributions for parameters which are not in the model.
#'
#' @param data ...
#' @param family ...
#' @param formula ...
#' @param prior ...
#' @param data A `data.frame` containing line list data
#' @param family Output of a call to `brms::brmsfamily()`
#' @param formula A formula object created using `brms::bf()`
#' @param prior User provided prior distribution created using `brms::prior()`
#' @rdname epidist_prior
#' @family prior
#' @export
Expand All @@ -23,7 +26,8 @@ epidist_prior <- function(data, family, formula, prior) {
default <- brms::default_prior(formula, data = data)
model <- epidist_model_prior(data, formula)
family <- epidist_family_prior(family, formula)
prior <- Reduce(.replace_prior, list(default, model, family, prior))
internal_prior <- Reduce(.replace_prior, list(default, model, family))
prior <- .replace_prior(internal_prior, prior, warn = TRUE)
return(prior)
}

Expand Down Expand Up @@ -90,15 +94,7 @@ epidist_family_prior.default <- function(family, formula, ...) {
#' @export
epidist_family_prior.lognormal <- function(family, formula, ...) {
prior <- prior("normal(1, 1)", class = "Intercept")
if ("sigma" %in% names(formula$pfix)) {
# Case with sigma fixed to a constant
sigma_prior <- NULL
} else {
# Case with a model on sigma
sigma_prior <- prior(
"normal(-0.7, 0.4)", class = "Intercept", dpar = "sigma"
)
}
sigma_prior <- prior("normal(-0.7, 0.4)", class = "Intercept", dpar = "sigma")
prior <- prior + sigma_prior
prior$source <- "family"
prior[is.na(prior)] <- "" # This is because brms likes empty over NA
Expand Down
34 changes: 18 additions & 16 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -45,37 +45,39 @@
#' Replace `brms` prior distributions
#'
#' This function takes `old_prior` and replaces any prior distributions
#' contained in it by the corresponding prior distribution in `new_prior`.
#' If there is a prior distribution in `new_prior` with no match in `old_prior`
#' then the function will error and give the name of the new prior distribution
#' with no match.
#' contained in it by the corresponding prior distribution in `prior`. If there
#' is a prior distribution in `prior` with no match in `old_prior` then this
#' function can optionally give a warning.
#'
#' @param old_prior One or more prior distributions in the class `brmsprior`
#' @param new_prior One or more prior distributions in the class `brmsprior`
#' @param prior One or more prior distributions in the class `brmsprior`
#' @param warn If `TRUE` then a warning will be displayed if a `new_prior` is
#' provided for which there is no matching `old_prior`. Defaults to `FALSE`
#' @autoglobal
#' @keywords internal
.replace_prior <- function(old_prior, new_prior) {
if (is.null(new_prior)) {
.replace_prior <- function(old_prior, prior, warn = FALSE) {
if (is.null(prior)) {
return(old_prior)
}
cols <- c("class", "coef", "group", "resp", "dpar", "nlpar", "lb", "ub")
prior <- dplyr::full_join(
old_prior, new_prior, by = cols, suffix = c("_old", "_new")
old_prior, prior, by = cols, suffix = c("_old", "_new")
)

if (any(is.na(prior$prior_old))) {
missing_prior <- utils::capture.output(print(
prior |>
filter(is.na(.data$prior_old)) |>
select(
prior = prior_new, dplyr::all_of(cols), source = source_new
)
select(prior = prior_new, dplyr::all_of(cols), source = source_new)
))
msg <- c(
"i" = "No available prior to replace in old_prior found for:",
missing_prior
)
cli_abort(message = msg)
if (warn) {
msg <- c(
"!" = "One or more priors have no match in existing parameters:",
missing_prior,
"i" = "To remove this warning consider changing prior specification."
)
cli_warn(message = msg)
}
}

prior <- prior |>
Expand Down
14 changes: 8 additions & 6 deletions man/dot-replace_prior.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions man/epidist.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 5 additions & 5 deletions man/epidist.default.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/epidist_family_prior.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/epidist_family_prior.default.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/epidist_family_prior.lognormal.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/epidist_model_prior.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/epidist_model_prior.default.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 8 additions & 5 deletions man/epidist_prior.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion tests/testthat/test-int-latent_individual.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ test_that("epidist.epidist_latent_individual fits and the MCMC converges in the
expect_convergence(fit)
})

test_that("epidist.epidist_latent_individual fit and the MCMC converges when setting sigma = 1 (a constant)", { # nolint: line_length_linter.
test_that("epidist.epidist_latent_individual fits, the MCMC converges, and the draws of sigma are indeed a constant, when setting sigma = 1 (a constant)", { # nolint: line_length_linter.
# Note: this test is stochastic. See note at the top of this script
skip_on_cran()
set.seed(1)
Expand All @@ -95,6 +95,8 @@ test_that("epidist.epidist_latent_individual fit and the MCMC converges when set
expect_s3_class(fit_constant, "brmsfit")
expect_s3_class(fit_constant, "epidist_fit")
expect_convergence(fit_constant)
sigma <- rstan::extract(fit_constant$fit, pars = "sigma")$sigma
expect_true(all(sigma == 1))
})

test_that("epidist.epidist_latent_individual Stan code has no syntax errors and compiles with lognormal family as a string", { # nolint: line_length_linter.
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ test_that(".replace_prior errors when passed a new prior without a match in old_
new_prior <- brms::prior("normal(0, 5)", class = "Intercept") +
brms::prior("normal(0, 5)", class = "Intercept", dpar = "sigma") +
brms::prior("normal(0, 5)", class = "Intercept", dpar = "shape")
expect_error(.replace_prior(old_prior, new_prior))
expect_warning(.replace_prior(old_prior, new_prior, warn = TRUE))
})

test_that(".add_dpar_info works as expected for the lognormal and gamma families", { # nolint: line_length_linter.
Expand Down

0 comments on commit 92c5104

Please sign in to comment.