Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return variable names and dimensions #256

Merged
merged 8 commits into from
Jul 30, 2020
19 changes: 13 additions & 6 deletions R/read_csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,19 @@ read_cmdstan_csv <- function(files,
metadata$lines_to_skip <- NULL
metadata$model_params <- repair_variable_names(metadata$model_params)
repaired_variables <- repair_variable_names(variables)
if (metadata$method == "variational") {
metadata$model_params <- metadata$model_params[metadata$model_params != "lp__"]
metadata$model_params <- gsub("log_p__", "lp__", metadata$model_params)
metadata$model_params <- gsub("log_g__", "lp_approx__", metadata$model_params)
repaired_variables <- repaired_variables[repaired_variables != "lp__"]
repaired_variables <- gsub("log_p__", "lp__", repaired_variables)
repaired_variables <- gsub("log_g__", "lp_approx__", repaired_variables)
}

model_param_dims <- variable_dims(metadata$model_params)
metadata$stan_variable_dims <- model_param_dims
metadata$stan_variables <- names(model_param_dims)

if (metadata$method == "sample") {
if (!is.null(warmup_draws)) {
posterior::variables(warmup_draws) <- repaired_variables
Expand All @@ -339,12 +352,6 @@ read_cmdstan_csv <- function(files,
post_warmup_sampler_diagnostics = post_warmup_sampler_diagnostics_draws
)
} else if (metadata$method == "variational") {
metadata$model_params <- metadata$model_params[metadata$model_params != "lp__"]
metadata$model_params <- gsub("log_p__", "lp__", metadata$model_params)
metadata$model_params <- gsub("log_g__", "lp_approx__", metadata$model_params)
repaired_variables <- repaired_variables[repaired_variables != "lp__"]
repaired_variables <- gsub("log_p__", "lp__", repaired_variables)
repaired_variables <- gsub("log_g__", "lp_approx__", repaired_variables)
if (!is.null(variational_draws)) {
posterior::variables(variational_draws) <- repaired_variables
}
Expand Down
33 changes: 33 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -331,3 +331,36 @@ matching_variables <- function(variable_filters, variables) {
not_found = not_found
)
}

#' Returns a list of dimensions for the input variables.
#'
#' @noRd
#' @param variable_names A character vector of variable names including all
#'   individual elements (e.g., `c("beta[1]", "beta[2]")`, not just `"beta"`).
#'   If `NULL`, `NULL` is returned.
#' @return A named list giving the dimensions of the variables, one entry per
#'   unique base name in order of first appearance. The equivalent of the
#'   `par_dims` slot of RStan's stanfit objects, except that scalars have
#'   dimension `1` instead of `0`.
#' @note The dimensions are read from the *last* element of each variable, so
#'   the input must already be sorted in ascending index order for the result
#'   to be correct. CmdStan always writes the variables in that order, so no
#'   additional sort is performed here.
#'
variable_dims <- function(variable_names = NULL) {
  if (is.null(variable_names)) {
    return(NULL)
  }
  dims <- list()
  # Base names with the whole "[...]" index part removed, e.g. "beta[1,2]"
  # becomes "beta".
  uniq_variable_names <- unique(gsub("\\[.*\\]", "", variable_names))
  # Drop the closing bracket so "beta[1,2]" becomes "beta[1,2"; what follows
  # the "name[" prefix is then just the comma-separated indices.
  var_names <- gsub("\\]", "", variable_names)
  for (var in uniq_variable_names) {
    # Match the prefix literally rather than as a regular expression, so that
    # names containing regex metacharacters (e.g. ".") cannot accidentally
    # match the elements of a different variable.
    prefix <- paste0(var, "[")
    var_elements <- var_names[which(startsWith(var_names, prefix))]
    if (length(var_elements) > 0) {
      # The last element carries the largest index along every dimension,
      # provided the input is sorted in ascending order.
      last_element <- var_elements[length(var_elements)]
      last_indices <- substring(last_element, nchar(prefix) + 1)
      dims[[var]] <- as.numeric(strsplit(last_indices, ",", fixed = TRUE)[[1]])
    } else {
      # No indexed elements: the variable is a scalar.
      dims[[var]] <- 1
    }
  }
  dims
}
32 changes: 32 additions & 0 deletions tests/testthat/test-csv.R
Original file line number Diff line number Diff line change
Expand Up @@ -480,3 +480,35 @@ test_that("read_cmdstan_csv() errors for files from different methods", {
"Supplied CSV files were produced by different methods and need to be read in separately!"
)
})

# Checks that read_cmdstan_csv() reports the variable names
# (`metadata$stan_variables`) and their dimensions
# (`metadata$stan_variable_dims`) for each of the fitted-model fixtures.
test_that("stan_variables and stan_variable_dims work in read_cmdstan_csv()", {
  skip_on_cran()
  # The fit_* objects are fixtures defined outside this block.
  bern_opt <- read_cmdstan_csv(fit_bernoulli_optimize$output_files())
  bern_vi <- read_cmdstan_csv(fit_bernoulli_variational$output_files())
  log_opt <- read_cmdstan_csv(fit_logistic_optimize$output_files())
  log_vi <- read_cmdstan_csv(fit_logistic_variational$output_files())
  bern_samp <- read_cmdstan_csv(fit_bernoulli_thin_1$output_files())
  log_samp <- read_cmdstan_csv(fit_logistic_thin_1$output_files())
  gq <- read_cmdstan_csv(fit_gq$output_files())

  # Variable names: the variational fits are expected to report lp_approx__
  # in addition to lp__.
  expect_equal(bern_opt$metadata$stan_variables, c("lp__", "theta"))
  expect_equal(bern_vi$metadata$stan_variables, c("lp__", "lp_approx__", "theta"))
  expect_equal(bern_samp$metadata$stan_variables, c("lp__", "theta"))

  expect_equal(log_opt$metadata$stan_variables, c("lp__", "alpha", "beta"))
  expect_equal(log_vi$metadata$stan_variables, c("lp__", "lp_approx__", "alpha", "beta"))
  expect_equal(log_samp$metadata$stan_variables, c("lp__", "alpha", "beta"))

  expect_equal(gq$metadata$stan_variables, c("y_rep", "sum_y"))

  # Variable dimensions: scalars are reported with dimension 1.
  expect_equal(bern_opt$metadata$stan_variable_dims, list(lp__ = 1, theta = 1))
  expect_equal(bern_vi$metadata$stan_variable_dims, list(lp__ = 1, lp_approx__ = 1, theta = 1))
  expect_equal(bern_samp$metadata$stan_variable_dims, list(lp__ = 1, theta = 1))

  expect_equal(log_opt$metadata$stan_variable_dims, list(lp__ = 1, alpha = 1, beta = 3))
  expect_equal(log_vi$metadata$stan_variable_dims, list(lp__ = 1, lp_approx__ = 1, alpha = 1, beta = 3))
  expect_equal(log_samp$metadata$stan_variable_dims, list(lp__ = 1, alpha = 1, beta = 3))

  expect_equal(gq$metadata$stan_variable_dims, list(y_rep = 10, sum_y = 1))
})

24 changes: 24 additions & 0 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,27 @@ test_that("cmdstan_make_local() works", {
cmdstan_make_local(cpp_options = as.list(exisiting_make_local), append = FALSE)
})

# Unit tests for variable_dims(): NULL input, scalars, vectors, matrices,
# overlapping names, and the documented behaviour on unsorted input.
test_that("variable_dims() works", {
  expect_null(variable_dims(NULL))

  # mix of a scalar, a vector and a matrix
  vars <- c("a", "b[1]", "b[2]", "b[3]", "c[1,1]", "c[1,2]")
  vars_dims <- list(a = 1, b = 3, c = c(1, 2))
  expect_equal(variable_dims(vars), vars_dims)

  # scalars only
  vars <- c("a", "b")
  vars_dims <- list(a = 1, b = 1)
  expect_equal(variable_dims(vars), vars_dims)

  # result follows the order in which the variables first appear
  vars <- c("c[1,1]", "c[1,2]", "c[1,3]", "c[2,1]", "c[2,2]", "c[2,3]", "b[1]", "b[2]", "b[3]", "b[4]")
  vars_dims <- list(c = c(2, 3), b = 4)
  expect_equal(variable_dims(vars), vars_dims)

  # make sure not confused by one name being last substring of another name
  vars <- c("a[1]", "a[2]", "aa[1]", "aa[2]", "aa[3]")
  expect_equal(variable_dims(vars), list(a = 2, aa = 3))

  # wrong dimensions for descending order: the dimensions are taken from the
  # last element of each variable, so unsorted input gives a wrong answer
  vars <- c("c[1,1]", "c[1,2]", "c[1,3]", "c[2,3]", "c[2,2]", "c[2,1]", "b[4]", "b[2]", "b[3]", "b[1]")
  vars_dims <- list(c = c(2, 1), b = 1)
  expect_equal(variable_dims(vars), vars_dims)
})