diff --git a/NAMESPACE b/NAMESPACE index 08ecab055..76f052cae 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -54,6 +54,7 @@ importFrom(checkmate,assert_numeric) importFrom(checkmate,assert_true) importFrom(cli,cli_abort) importFrom(cli,cli_inform) +importFrom(cli,cli_warn) importFrom(dplyr,filter) importFrom(dplyr,mutate) importFrom(dplyr,select) diff --git a/R/epidist-package.R b/R/epidist-package.R index 293ed97f5..0859a1a6a 100644 --- a/R/epidist-package.R +++ b/R/epidist-package.R @@ -8,6 +8,6 @@ #' @importFrom brms bf prior #' @importFrom checkmate assert_data_frame assert_names assert_integer #' assert_true assert_factor assert_numeric -#' @importFrom cli cli_abort cli_inform cli_abort +#' @importFrom cli cli_abort cli_inform cli_abort cli_warn ## usethis namespace: end NULL diff --git a/R/generics.R b/R/generics.R index 851288518..cd5d70147 100644 --- a/R/generics.R +++ b/R/generics.R @@ -30,7 +30,7 @@ epidist_stancode <- function(data, ...) { #' Fit epidemiological delay distributions using a `brms` interface #' #' @inheritParams epidist_validate -#' @param formula A formula object created using `brms::bf`. A formula must be +#' @param formula A formula object created using [brms::bf()]. A formula must be #' provided for the distributional parameter `mu` common to all `brms` families. #' Optionally, formulas may also be provided for additional distributional #' parameters. @@ -46,10 +46,10 @@ epidist_stancode <- function(data, ...) { #' @param backend Character string naming the package to use as the backend for #' fitting the Stan model. Options are `"rstan"` and `"cmdstanr"` (the default). #' This option is passed directly through to `fn`. -#' @param fn The internal function to be called. By default this is `brms::brm`, -#' which performs inference for the specified model. Other options -#' `brms::make_stancode`, which returns the Stan code for the specified model, -#' and `brms::make_standata` which returns the data passed to Stan. These +#' @param fn The internal function to be called. By default this is +#' [brms::brm()] which performs inference for the specified model. Other options +#' [brms::make_stancode()], which returns the Stan code for the specified model, +#' and [brms::make_standata()] which returns the data passed to Stan. These #' options may be useful for model debugging and extensions. #' @param ... Additional arguments for method. #' @family generics diff --git a/R/prior.R b/R/prior.R index c4ff2175b..839abae3b 100644 --- a/R/prior.R +++ b/R/prior.R @@ -7,12 +7,15 @@ #' 2. Family specific prior distributions from [epidist_family_prior()] #' 3. User provided prior distributions #' Each element of this list overwrites previous elements, such that user -#' provided prior distribution have the highest priority. +#' provided prior distribution have the highest priority. At the third stage, +#' if a prior distribution is provided which is not included in the model, then +#' a warning will be shown. To prevent this warning, do not pass prior +#' distributions for parameters which are not in the model. #' -#' @param data ... -#' @param family ... -#' @param formula ... -#' @param prior ... +#' @param data A `data.frame` containing line list data +#' @param family Output of a call to `brms::brmsfamily()` +#' @param formula A formula object created using `brms::bf()` +#' @param prior User provided prior distribution created using `brms::prior()` #' @rdname epidist_prior #' @family prior #' @export @@ -23,7 +26,8 @@ epidist_prior <- function(data, family, formula, prior) { default <- brms::default_prior(formula, data = data) model <- epidist_model_prior(data, formula) family <- epidist_family_prior(family, formula) - prior <- Reduce(.replace_prior, list(default, model, family, prior)) + internal_prior <- Reduce(.replace_prior, list(default, model, family)) + prior <- .replace_prior(internal_prior, prior, warn = TRUE) return(prior) } @@ -90,15 +94,7 @@ epidist_family_prior.default <- function(family, formula, ...) { #' @export epidist_family_prior.lognormal <- function(family, formula, ...) { prior <- prior("normal(1, 1)", class = "Intercept") - if ("sigma" %in% names(formula$pfix)) { - # Case with sigma fixed to a constant - sigma_prior <- NULL - } else { - # Case with a model on sigma - sigma_prior <- prior( - "normal(-0.7, 0.4)", class = "Intercept", dpar = "sigma" - ) - } + sigma_prior <- prior("normal(-0.7, 0.4)", class = "Intercept", dpar = "sigma") prior <- prior + sigma_prior prior$source <- "family" prior[is.na(prior)] <- "" # This is because brms likes empty over NA diff --git a/R/utils.R b/R/utils.R index aa4ce62ee..39d5f28b0 100644 --- a/R/utils.R +++ b/R/utils.R @@ -45,37 +45,39 @@ #' Replace `brms` prior distributions #' #' This function takes `old_prior` and replaces any prior distributions -#' contained in it by the corresponding prior distribution in `new_prior`. -#' If there is a prior distribution in `new_prior` with no match in `old_prior` -#' then the function will error and give the name of the new prior distribution -#' with no match. +#' contained in it by the corresponding prior distribution in `prior`. If there +#' is a prior distribution in `prior` with no match in `old_prior` then this +#' function can optionally give a warning. #' #' @param old_prior One or more prior distributions in the class `brmsprior` -#' @param new_prior One or more prior distributions in the class `brmsprior` +#' @param prior One or more prior distributions in the class `brmsprior` +#' @param warn If `TRUE` then a warning will be displayed if a `new_prior` is +#' provided for which there is no matching `old_prior`. Defaults to `FALSE` #' @autoglobal #' @keywords internal -.replace_prior <- function(old_prior, new_prior) { - if (is.null(new_prior)) { +.replace_prior <- function(old_prior, prior, warn = FALSE) { + if (is.null(prior)) { return(old_prior) } cols <- c("class", "coef", "group", "resp", "dpar", "nlpar", "lb", "ub") prior <- dplyr::full_join( - old_prior, new_prior, by = cols, suffix = c("_old", "_new") + old_prior, prior, by = cols, suffix = c("_old", "_new") ) if (any(is.na(prior$prior_old))) { missing_prior <- utils::capture.output(print( prior |> filter(is.na(.data$prior_old)) |> - select( - prior = prior_new, dplyr::all_of(cols), source = source_new - ) + select(prior = prior_new, dplyr::all_of(cols), source = source_new) )) - msg <- c( - "i" = "No available prior to replace in old_prior found for:", - missing_prior - ) - cli_abort(message = msg) + if (warn) { + msg <- c( + "!" = "One or more priors have no match in existing parameters:", + missing_prior, + "i" = "To remove this warning consider changing prior specification." + ) + cli_warn(message = msg) + } } prior <- prior |> diff --git a/man/dot-replace_prior.Rd b/man/dot-replace_prior.Rd index 5abcb889a..3c70a95de 100644 --- a/man/dot-replace_prior.Rd +++ b/man/dot-replace_prior.Rd @@ -4,18 +4,20 @@ \alias{.replace_prior} \title{Replace \code{brms} prior distributions} \usage{ -.replace_prior(old_prior, new_prior) +.replace_prior(old_prior, prior, warn = FALSE) } \arguments{ \item{old_prior}{One or more prior distributions in the class \code{brmsprior}} -\item{new_prior}{One or more prior distributions in the class \code{brmsprior}} +\item{prior}{One or more prior distributions in the class \code{brmsprior}} + +\item{warn}{If \code{TRUE} then a warning will be displayed if a \code{new_prior} is +provided for which there is no matching \code{old_prior}. Defaults to \code{FALSE}} } \description{ This function takes \code{old_prior} and replaces any prior distributions -contained in it by the corresponding prior distribution in \code{new_prior}. -If there is a prior distribution in \code{new_prior} with no match in \code{old_prior} -then the function will error and give the name of the new prior distribution -with no match. +contained in it by the corresponding prior distribution in \code{prior}. If there +is a prior distribution in \code{prior} with no match in \code{old_prior} then this +function can optionally give a warning. } \keyword{internal} diff --git a/man/epidist.Rd b/man/epidist.Rd index 0620006ab..6ace0b497 100644 --- a/man/epidist.Rd +++ b/man/epidist.Rd @@ -9,7 +9,7 @@ epidist(data, formula, family, prior, backend, fn, ...) \arguments{ \item{data}{A \code{data.frame} containing line list data.} -\item{formula}{A formula object created using \code{brms::bf}. A formula must be +\item{formula}{A formula object created using \code{\link[brms:brmsformula]{brms::bf()}}. A formula must be provided for the distributional parameter \code{mu} common to all \code{brms} families. Optionally, formulas may also be provided for additional distributional parameters.} @@ -29,10 +29,10 @@ for specifying prior distributions.} fitting the Stan model. Options are \code{"rstan"} and \code{"cmdstanr"} (the default). This option is passed directly through to \code{fn}.} -\item{fn}{The internal function to be called. By default this is \code{brms::brm}, -which performs inference for the specified model. Other options -\code{brms::make_stancode}, which returns the Stan code for the specified model, -and \code{brms::make_standata} which returns the data passed to Stan. These +\item{fn}{The internal function to be called. By default this is +\code{\link[brms:brm]{brms::brm()}} which performs inference for the specified model. Other options +\code{\link[brms:stancode]{brms::make_stancode()}}, which returns the Stan code for the specified model, +and \code{\link[brms:standata]{brms::make_standata()}} which returns the data passed to Stan. These options may be useful for model debugging and extensions.} \item{...}{Additional arguments for method.} diff --git a/man/epidist.default.Rd b/man/epidist.default.Rd index 6a7f45ac4..6dac69030 100644 --- a/man/epidist.default.Rd +++ b/man/epidist.default.Rd @@ -17,7 +17,7 @@ \arguments{ \item{data}{A \code{data.frame} containing line list data.} -\item{formula}{A formula object created using \code{brms::bf}. A formula must be +\item{formula}{A formula object created using \code{\link[brms:brmsformula]{brms::bf()}}. A formula must be provided for the distributional parameter \code{mu} common to all \code{brms} families. Optionally, formulas may also be provided for additional distributional parameters.} @@ -37,10 +37,10 @@ for specifying prior distributions.} fitting the Stan model. Options are \code{"rstan"} and \code{"cmdstanr"} (the default). This option is passed directly through to \code{fn}.} -\item{fn}{The internal function to be called. By default this is \code{brms::brm}, -which performs inference for the specified model. Other options -\code{brms::make_stancode}, which returns the Stan code for the specified model, -and \code{brms::make_standata} which returns the data passed to Stan. These +\item{fn}{The internal function to be called. By default this is +\code{\link[brms:brm]{brms::brm()}} which performs inference for the specified model. Other options +\code{\link[brms:stancode]{brms::make_stancode()}}, which returns the Stan code for the specified model, +and \code{\link[brms:standata]{brms::make_standata()}} which returns the data passed to Stan. These options may be useful for model debugging and extensions.} \item{...}{Additional arguments for method.} diff --git a/man/epidist_family_prior.Rd b/man/epidist_family_prior.Rd index b58604220..72afd6056 100644 --- a/man/epidist_family_prior.Rd +++ b/man/epidist_family_prior.Rd @@ -7,7 +7,7 @@ epidist_family_prior(family, ...) } \arguments{ -\item{family}{...} +\item{family}{Output of a call to \code{brms::brmsfamily()}} \item{...}{...} } diff --git a/man/epidist_family_prior.default.Rd b/man/epidist_family_prior.default.Rd index 7fe8f7c19..ea58083af 100644 --- a/man/epidist_family_prior.default.Rd +++ b/man/epidist_family_prior.default.Rd @@ -7,9 +7,9 @@ \method{epidist_family_prior}{default}(family, formula, ...) } \arguments{ -\item{family}{...} +\item{family}{Output of a call to \code{brms::brmsfamily()}} -\item{formula}{...} +\item{formula}{A formula object created using \code{brms::bf()}} \item{...}{...} } diff --git a/man/epidist_family_prior.lognormal.Rd b/man/epidist_family_prior.lognormal.Rd index c31141538..827b2fc86 100644 --- a/man/epidist_family_prior.lognormal.Rd +++ b/man/epidist_family_prior.lognormal.Rd @@ -7,9 +7,9 @@ \method{epidist_family_prior}{lognormal}(family, formula, ...) } \arguments{ -\item{family}{...} +\item{family}{Output of a call to \code{brms::brmsfamily()}} -\item{formula}{...} +\item{formula}{A formula object created using \code{brms::bf()}} \item{...}{...} } diff --git a/man/epidist_model_prior.Rd b/man/epidist_model_prior.Rd index 20f68fb60..36f3a35d5 100644 --- a/man/epidist_model_prior.Rd +++ b/man/epidist_model_prior.Rd @@ -7,7 +7,7 @@ epidist_model_prior(data, ...) } \arguments{ -\item{data}{...} +\item{data}{A \code{data.frame} containing line list data} \item{...}{...} } diff --git a/man/epidist_model_prior.default.Rd b/man/epidist_model_prior.default.Rd index d506cf252..f9de37fa6 100644 --- a/man/epidist_model_prior.default.Rd +++ b/man/epidist_model_prior.default.Rd @@ -7,9 +7,9 @@ \method{epidist_model_prior}{default}(data, formula, ...) } \arguments{ -\item{data}{...} +\item{data}{A \code{data.frame} containing line list data} -\item{formula}{...} +\item{formula}{A formula object created using \code{brms::bf()}} \item{...}{...} } diff --git a/man/epidist_prior.Rd b/man/epidist_prior.Rd index 8d64b4089..2454db586 100644 --- a/man/epidist_prior.Rd +++ b/man/epidist_prior.Rd @@ -8,13 +8,13 @@ family specific priors, and user provided priors} epidist_prior(data, family, formula, prior) } \arguments{ -\item{data}{...} +\item{data}{A \code{data.frame} containing line list data} -\item{family}{...} +\item{family}{Output of a call to \code{brms::brmsfamily()}} -\item{formula}{...} +\item{formula}{A formula object created using \code{brms::bf()}} -\item{prior}{...} +\item{prior}{User provided prior distribution created using \code{brms::prior()}} } \description{ This function obtains the \code{brms} default prior distributions for a particular @@ -24,7 +24,10 @@ model, then replaces these prior distributions using: \item Family specific prior distributions from \code{\link[=epidist_family_prior]{epidist_family_prior()}} \item User provided prior distributions Each element of this list overwrites previous elements, such that user -provided prior distribution have the highest priority. +provided prior distribution have the highest priority. At the third stage, +if a prior distribution is provided which is not included in the model, then +a warning will be shown. To prevent this warning, do not pass prior +distributions for parameters which are not in the model. } } \seealso{ diff --git a/tests/testthat/test-int-latent_individual.R b/tests/testthat/test-int-latent_individual.R index e33ccbd4a..c3758999f 100644 --- a/tests/testthat/test-int-latent_individual.R +++ b/tests/testthat/test-int-latent_individual.R @@ -81,7 +81,7 @@ test_that("epidist.epidist_latent_individual fits and the MCMC converges in the expect_convergence(fit) }) -test_that("epidist.epidist_latent_individual fit and the MCMC converges when setting sigma = 1 (a constant)", { # nolint: line_length_linter. +test_that("epidist.epidist_latent_individual fits, the MCMC converges, and the draws of sigma are indeed a constant, when setting sigma = 1 (a constant)", { # nolint: line_length_linter. # Note: this test is stochastic. See note at the top of this script skip_on_cran() set.seed(1) @@ -95,6 +95,8 @@ test_that("epidist.epidist_latent_individual fit and the MCMC converges when set expect_s3_class(fit_constant, "brmsfit") expect_s3_class(fit_constant, "epidist_fit") expect_convergence(fit_constant) + sigma <- rstan::extract(fit_constant$fit, pars = "sigma")$sigma + expect_true(all(sigma == 1)) }) test_that("epidist.epidist_latent_individual Stan code has no syntax errors and compiles with lognormal family as a string", { # nolint: line_length_linter. diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index c39a3c38b..0f931dc9a 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -23,7 +23,7 @@ test_that(".replace_prior errors when passed a new prior without a match in old_ new_prior <- brms::prior("normal(0, 5)", class = "Intercept") + brms::prior("normal(0, 5)", class = "Intercept", dpar = "sigma") + brms::prior("normal(0, 5)", class = "Intercept", dpar = "shape") - expect_error(.replace_prior(old_prior, new_prior)) + expect_warning(.replace_prior(old_prior, new_prior, warn = TRUE)) }) test_that(".add_dpar_info works as expected for the lognormal and gamma families", { # nolint: line_length_linter.