Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get parameters #646

Merged
merged 14 commits into from
May 3, 2024
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@ export(forecast_secondary)
export(gamma_dist_def)
export(generation_time_opts)
export(get_dist)
export(get_distribution)
export(get_generation_time)
export(get_incubation_period)
export(get_parameters)
export(get_pmf)
export(get_raw_result)
export(get_regional_results)
export(get_regions)
Expand All @@ -77,6 +80,7 @@ export(growth_to_R)
export(lognorm_dist_def)
export(make_conf)
export(map_prob_change)
export(new_dist_spec)
export(obs_opts)
export(opts_list)
export(plot_estimates)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* A new `simulate_infections` function has been added that can be used to simulate from the model from given initial conditions and parameters. By @sbfnk in #557 and reviewed by @jamesmbaazam.
* The function `init_cumulative_fit()` has been deprecated. By @jamesmbaazam in #541 and reviewed by @sbfnk.
* The interface to generating delay distributions has been completely overhauled. Instead of calling `dist_spec()` users now specify distributions using functions that represent the available distributions, i.e. `LogNormal()`, `Gamma()` and `Fixed()`. Uncertainty is specified using calls of the same nature, to `Normal()`. More information on the underlying design can be found in `inst/dev/design_dist.md` By @sbfnk in #504 and reviewed by @seabbs.
* The accessor functions `get_parameters()`, `get_pmf()`, and `get_distribution()` have been added to extract elements of a <dist_spec> object. By @sbfnk in #646 and reviewed by @jamesmbaazam.
* The functions `sample_approx_dist()`, `report_cases()`, and `adjust_infection_reports()` have been deprecated as the functionality they provide can now be achieved with `simulate_secondary()`. See #597 by @jamesmbaazam and reviewed by @sbfnk.

## Documentation
Expand Down
4 changes: 1 addition & 3 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ create_clean_reported_cases <- function(data, horizon = 0,
#' @description `r lifecycle::badge("stable")`
#' Creates a complete data set without NA values and appropriate indices
#'
#' @param cases; data frame with a column "confirm" that may contain NA values
#' @param burn_in; integer (default 0). Number of days to remove from the
#' start of the time series be filtered out.
#' @param cases data frame with a column "confirm" that may contain NA values
#'
#' @return A data frame without NA values, with two columns: confirm (number)
#' @importFrom data.table setDT
Expand Down
4 changes: 2 additions & 2 deletions R/deprecated.R
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ dist_spec <- function(distribution = c(
mean = Normal(mean, mean_sd),
sd = Normal(sd, sd_sd)
)
params_mean <- vapply(temp_dist[[1]]$parameters, mean, numeric(1))
params_sd <- vapply(temp_dist[[1]]$parameters, sd_dist, numeric(1))
params_mean <- vapply(get_parameters(temp_dist), mean, numeric(1))
params_sd <- vapply(get_parameters(temp_dist), sd_dist, numeric(1))
} else if (distribution == "normal") {
params_mean <- c(mean = mean, sd = sd)
params_sd <- c(mean = mean_sd, sd = sd_sd)
Expand Down
113 changes: 99 additions & 14 deletions R/dist_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -582,10 +582,10 @@ print.dist_spec <- function(x, ...) {
} else if (x[[i]]$distribution == "fixed") {
## fixed
cat(indent_str, "- fixed value:\n", sep = "")
if (is.numeric(x[[i]]$parameters$value)) {
cat(indent_str, " ", x[[i]]$parameters$value, "\n", sep = "")
if (is.numeric(get_parameters(x, i)$value)) {
cat(indent_str, " ", get_parameters(x, i)$value, "\n", sep = "")
} else {
.print.dist_spec(x[[i]]$parameters$value, indent = indent + 4)
.print.dist_spec(get_parameters(x, i)$value, indent = indent + 4)
}
} else {
## parametric
Expand All @@ -595,18 +595,18 @@ print.dist_spec <- function(x, ...) {
}
cat(":\n")
## loop over natural parameters and print
for (param in names(x[[i]]$parameters)) {
for (param in names(get_parameters(x, i))) {
cat(
indent_str, " ", param, ":\n", sep = ""
)
if (is.numeric(x[[i]]$parameters[[param]])) {
if (is.numeric(get_parameters(x, i)[[param]])) {
cat(
indent_str, " ",
signif(x[[i]]$parameters[[param]], digits = 2), "\n",
signif(get_parameters(x, i)[[param]], digits = 2), "\n",
sep = ""
)
} else {
.print.dist_spec(x[[i]]$parameters[[param]], indent = indent + 4)
.print.dist_spec(get_parameters(x, i)[[param]], indent = indent + 4)
}
}
}
Expand Down Expand Up @@ -654,12 +654,12 @@ plot.dist_spec <- function(x, ...) {
for (i in seq_along(x)) {
if (x[[i]]$distribution == "nonparametric") {
# Fixed distribution
pmf <- x[[i]]$pmf
pmf <- get_pmf(x, i)
dist_name <- paste0("Nonparametric", " (ID: ", i, ")")
} else {
# Uncertain distribution
c_dist <- discretise(fix_dist(extract_single_dist(x, i)))
pmf <- c_dist[[1]]$pmf
pmf <- get_pmf(c_dist)
dist_name <- paste0(
ifelse(is.na(dist_sd[i]), "Uncertain ", ""),
x[[i]]$distribution, " (ID: ", i, ")"
Expand Down Expand Up @@ -951,14 +951,12 @@ extract_params <- function(params, distribution) {
#' @inheritParams extract_params
#' @importFrom purrr walk
#' @return A `dist_spec` of the given specification.
#' @keywords internal
#' @export
#' @examples
#' \dontrun{
#' new_dist_spec(
#' params = list(mean = 2, sd = 1, max = Inf),
#' distribution = "normal"
#' )
#' }
new_dist_spec <- function(params, distribution) {
if (distribution == "nonparametric") {
## nonparametric distribution
Expand All @@ -968,8 +966,12 @@ new_dist_spec <- function(params, distribution) {
)
} else {
## process min/max first
max <- params$max
params$max <- NULL
if (is.null(params$max)) {
max <- Inf
} else {
max <- params$max
params$max <- NULL
}
## extract parameters and convert all to dist_spec
params <- extract_params(params, distribution)
## fixed distribution
Expand Down Expand Up @@ -1100,3 +1102,86 @@ convert_to_natural <- function(params, distribution) {
}
return(params)
}

##' Perform checks for `<dist_spec>` `get_...` functions
##'
##' @param x A `<dist_spec>`.
##' @param id Integer; the id of the distribution to get parameters of (if x is
##' a composite distribution). If `x` is a single distribution this is ignored
##' and can be left as `NULL`.
##' @return The id to use.
##' @keywords internal
##' @author Sebastian Funk
get_dist_spec_id <- function(x, id) {
if (!is.null(id) && id > length(x)) {
stop(
"`id` can't be greater than the number of distributions (", length(x),
")."
)
}
if (length(x) > 1) {
if (is.null(id)) {
stop("`id` must be specified when `x` is a composite distribution.")
}
} else {
id <- 1
}
return(id)
}

##' Get parameters of a parametric distribution
##'
##' @inheritParams get_dist_spec_id
##' @description `r lifecycle::badge("experimental")`
##' @return A list of parameters of the distribution.
##' @export
##' @examples
##' dist <- Gamma(shape = 3, rate = 2)
##' get_parameters(dist)
get_parameters <- function(x, id = NULL) {
if (!is(x, "dist_spec")) {
stop("Can only get parameters of a <dist_spec>.")
}
id <- get_dist_spec_id(x, id)
if (x[[id]]$distribution == "nonparametric") {
stop("Cannot get parameters of a nonparametric distribution.")
}
return(x[[id]]$parameters)
}

##' Get the probability mass function of a nonparametric distribution
##'
##' @inheritParams get_dist_spec_id
##' @description `r lifecycle::badge("experimental")`
##' @return The pmf of the distribution
##' @export
##' @examples
##' dist <- discretise(Gamma(shape = 3, rate = 2, max = 10))
##' get_pmf(dist)
get_pmf <- function(x, id = NULL) {
if (!is(x, "dist_spec")) {
stop("Can only get pmf of a <dist_spec>.")
}
id <- get_dist_spec_id(x, id)
if (x[[id]]$distribution != "nonparametric") {
stop("Cannot get pmf of a parametric distribution.")
}
return(x[[id]]$pmf)
}

##' Get the distribution of a [dist_spec()]
##'
##' @inheritParams get_dist_spec_id
##' @description `r lifecycle::badge("experimental")`
##' @return A character string naming the distribution (or "nonparametric")
##' @export
##' @examples
##' dist <- Gamma(shape = 3, rate = 2, max = 10)
##' get_distribution(dist)
get_distribution <- function(x, id = NULL) {
if (!is(x, "dist_spec")) {
stop("Can only get distribution of a <dist_spec>.")
}
id <- get_dist_spec_id(x, id)
return(x[[id]]$distribution)
}
11 changes: 7 additions & 4 deletions R/estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
#' # illustrative purposes only.
#' out <- epinow(
#' example_truncated[[5]],
#' truncation = est$dist
#' truncation = trunc_opts(est$dist)
#' )
#' plot(out)
#' options(old_opts)
Expand Down Expand Up @@ -291,9 +291,12 @@ estimate_truncation <- function(data, max_truncation, trunc_max = 10,
parameters <- purrr::map(seq_along(params_mean), function(id) {
Normal(params_mean[id], params_sd[id])
})
names(parameters) <- natural_params(truncation[[1]]$distribution)
out$dist <- truncation
out$dist[[1]]$parameters <- parameters
names(parameters) <- natural_params(get_distribution(truncation))
parameters$max <- max(truncation)
out$dist <- new_dist_spec(
params = parameters,
distribution = get_distribution(truncation)
)

# summarise reconstructed observations
recon_obs <- extract_stan_param(fit, "recon_obs",
Expand Down
6 changes: 5 additions & 1 deletion _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,11 @@ reference:
- apply_tolerance
- collapse
- discretise
- contains("dist")
- contains("_dist")
- contains("dist_")
- get_parameters
- get_pmf
- get_distribution
- title: Simulate
desc: Functions to help with simulating data or mapping to reported cases
contents:
Expand Down
10 changes: 5 additions & 5 deletions data-raw/truncated.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ library("EpiNow2")
#' @keywords internal
apply_truncation <- function(index, data, dist) {
set.seed(index)
if (dist[[1]]$distribution == 0) {
if (get_distribution(dist) == "lognormal") {
dfunc <- dlnorm
} else {
dfunc <- dgamma
Expand All @@ -20,12 +20,12 @@ apply_truncation <- function(index, data, dist) {
dfunc(
seq_len(max(dist) + 1),
rnorm(1,
dist[[1]]$parameters$meanlog[[1]]$parameters$mean,
dist[[1]]$parameters$meanlog[[1]]$parameters$sd
get_parameters(get_parameters(dist)$meanlog)$mean,
get_parameters(get_parameters(dist)$meanlog)$sd
),
rnorm(1,
dist[[1]]$parameters$sdlog[[1]]$parameters$mean,
dist[[1]]$parameters$sdlog[[1]]$parameters$sd
get_parameters(get_parameters(dist)$sdlog)$mean,
get_parameters(get_parameters(dist)$sdlog)$sd
)
)
)
Expand Down
5 changes: 1 addition & 4 deletions man/create_complete_cases.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/estimate_truncation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/get_dist_spec_id.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/get_distribution.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/get_parameters.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading