Skip to content

Commit

Permalink
Automatically init model methods, add inc_warmup argument to unconstr…
Browse files Browse the repository at this point in the history
…ain_draws
  • Loading branch information
andrjohns committed May 24, 2024
1 parent a4678d5 commit 554c0a0
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 19 deletions.
32 changes: 17 additions & 15 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,6 @@ init_model_methods <- function(seed = 1, verbose = FALSE, hessian = FALSE) {
"errors that you encounter")
}
if (is.null(private$model_methods_env_$model_ptr)) {
if (verbose) {
message("Compiling additional model methods...")
}
expose_model_methods(private$model_methods_env_, verbose, hessian)
}
if (!("model_ptr_" %in% ls(private$model_methods_env_))) {
Expand Down Expand Up @@ -527,6 +524,8 @@ CmdStanFit$set("public", name = "unconstrain_variables", value = unconstrain_var
#' @param draws A `posterior::draws_*` object.
#' @param format (string) The format of the returned draws. Must be a valid
#' format from the \pkg{posterior} package.
#' @param inc_warmup (logical) Should warmup draws be included? Defaults to
#' `FALSE`.
#'
#' @examples
#' \dontrun{
Expand Down Expand Up @@ -559,22 +558,25 @@ unconstrain_draws <- function(files = NULL, draws = NULL,
call. = FALSE)
}
if (!is.null(files)) {
read_csv <- read_cmdstan_csv(files = files, format = "draws_matrix")
draws <- read_csv$post_warmup_draws
}
if (!is.null(draws)) {
draws <- maybe_convert_draws_format(draws, "draws_matrix")
}
} else {
if (is.null(private$draws_)) {
if (!length(self$output_files(include_failed = FALSE))) {
stop("Fitting failed. Unable to retrieve the draws.", call. = FALSE)
read_csv <- read_cmdstan_csv(files = files)
if (inc_warmup) {
draws <- posterior::bind_draws(read_csv$warmup_draws,
read_csv$post_warmup_draws,
along = "iteration")
} else {
draws <- read_csv$post_warmup_draws
}
} else if (!is.null(draws)) {
if (inc_warmup) {
message("'inc_warmup' cannot be used with a draws object. Ignoring.")
}
private$read_csv_(format = "draws_df")
}
draws <- maybe_convert_draws_format(private$draws_, "draws_matrix")
} else {
draws <- self$draws(inc_warmup = inc_warmup)
}

draws <- maybe_convert_draws_format(draws, "draws_matrix")

chains <- posterior::nchains(draws)

model_par_names <- self$metadata()$stan_variables[self$metadata()$stan_variables != "lp__"]
Expand Down
7 changes: 6 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,9 @@ rcpp_source_stan <- function(code, env, verbose = FALSE, ...) {
}

expose_model_methods <- function(env, verbose = FALSE, hessian = FALSE) {
if (rlang::is_interactive()) {
message("Compiling additional model methods...")
}
code <- c(env$hpp_code_,
readLines(system.file("include", "model_methods.cpp",
package = "cmdstanr", mustWork = TRUE)))
Expand Down Expand Up @@ -1034,7 +1037,9 @@ expose_stan_functions <- function(function_env, global = FALSE, verbose = FALSE)
})
}
} else {
message("Compiling standalone functions...")
if (rlang::is_interactive()) {
message("Compiling standalone functions...")
}
compile_functions(function_env, verbose, global)
}
invisible(NULL)
Expand Down
3 changes: 3 additions & 0 deletions man/fit-method-unconstrain_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 16 additions & 3 deletions tests/testthat/test-model-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -179,21 +179,34 @@ test_that("unconstrain_draws returns correct values", {
mod <- cmdstan_model(write_stan_file(model_code),
compile_model_methods = TRUE,
force_recompile = TRUE)
fit <- mod$sample(data = list(N = 0), chains = 2)
fit <- mod$sample(data = list(N = 0), chains = 2, save_warmup = TRUE)
fit_no_warmup <- mod$sample(data = list(N = 0), chains = 2)

x_draws <- fit$draws(format = "draws_df")$x

x_draws_warmup <- fit$draws(format = "draws_df", inc_warmup = TRUE)$x

# Unconstrain all internal draws
unconstrained_internal_draws <- fit$unconstrain_draws()
unconstrained_internal_draws_warmup <- fit$unconstrain_draws(inc_warmup = TRUE)
expect_equal(as.numeric(x_draws), as.numeric(unconstrained_internal_draws))

expect_equal(as.numeric(x_draws_warmup), as.numeric(unconstrained_internal_draws_warmup))

expect_error({unconstrained_internal_draws <- fit_no_warmup$unconstrain_draws(inc_warmup = TRUE)},
"Warmup draws were requested from a fit object without them! Please rerun the model with save_warmup = TRUE.")

# Unconstrain external CmdStan CSV files
unconstrained_csv <- fit$unconstrain_draws(files = fit$output_files())
unconstrained_csv_warmup <- fit$unconstrain_draws(files = fit$output_files(),
inc_warmup = TRUE)
expect_equal(as.numeric(x_draws), as.numeric(unconstrained_csv))
expect_equal(as.numeric(x_draws_warmup), as.numeric(unconstrained_csv_warmup))

# Unconstrain existing draws object
unconstrained_draws <- fit$unconstrain_draws(draws = fit$draws())
expect_equal(as.numeric(x_draws), as.numeric(unconstrained_draws))

expect_message(fit$unconstrain_draws(draws = fit$draws(), inc_warmup = TRUE),
"'inc_warmup' cannot be used with a draws object. Ignoring.")

# With a lower-bounded constraint, the parameter draws should be the
# exponentiation of the unconstrained draws
Expand Down

0 comments on commit 554c0a0

Please sign in to comment.