Skip to content

Commit

Permalink
Get parameters (#646)
Browse files Browse the repository at this point in the history
* add functions for accessing parts of `dist_spec()`

* update use to get_ functions

* add news item

* set max when constructing parametric dist_spec

* add PR number and reviewer

Co-authored-by: James Azam <[email protected]>

* clearer get_dist_spec_id function

* update documentation

---------

Co-authored-by: James Azam <[email protected]>
  • Loading branch information
sbfnk and jamesmbaazam authored May 3, 2024
1 parent 29ce5ab commit c3218d2
Show file tree
Hide file tree
Showing 17 changed files with 266 additions and 53 deletions.
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,11 @@ export(forecast_secondary)
export(gamma_dist_def)
export(generation_time_opts)
export(get_dist)
export(get_distribution)
export(get_generation_time)
export(get_incubation_period)
export(get_parameters)
export(get_pmf)
export(get_raw_result)
export(get_regional_results)
export(get_regions)
Expand All @@ -77,6 +80,7 @@ export(growth_to_R)
export(lognorm_dist_def)
export(make_conf)
export(map_prob_change)
export(new_dist_spec)
export(obs_opts)
export(opts_list)
export(plot_estimates)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* A new `simulate_infections` function has been added that can be used to simulate from the model from given initial conditions and parameters. By @sbfnk in #557 and reviewed by @jamesmbaazam.
* The function `init_cumulative_fit()` has been deprecated. By @jamesmbaazam in #541 and reviewed by @sbfnk.
* The interface to generating delay distributions has been completely overhauled. Instead of calling `dist_spec()` users now specify distributions using functions that represent the available distributions, i.e. `LogNormal()`, `Gamma()` and `Fixed()`. Uncertainty is specified using calls of the same nature, to `Normal()`. More information on the underlying design can be found in `inst/dev/design_dist.md` By @sbfnk in #504 and reviewed by @seabbs.
* The accessor functions `get_parameters()`, `get_pmf()`, and `get_distribution()` have been added to extract elements of a <dist_spec> object. By @sbfnk in #646 and reviewed by @jamesmbaazam.
* The functions `sample_approx_dist()`, `report_cases()`, and `adjust_infection_reports()` have been deprecated as the functionality they provide can now be achieved with `simulate_secondary()`. See #597 by @jamesmbaazam and reviewed by @sbfnk.

## Documentation
Expand Down
4 changes: 1 addition & 3 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ create_clean_reported_cases <- function(data, horizon = 0,
#' @description `r lifecycle::badge("stable")`
#' Creates a complete data set without NA values and appropriate indices
#'
#' @param cases; data frame with a column "confirm" that may contain NA values
#' @param burn_in; integer (default 0). Number of days to remove from the
#' start of the time series be filtered out.
#' @param cases data frame with a column "confirm" that may contain NA values
#'
#' @return A data frame without NA values, with two columns: confirm (number)
#' @importFrom data.table setDT
Expand Down
4 changes: 2 additions & 2 deletions R/deprecated.R
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ dist_spec <- function(distribution = c(
mean = Normal(mean, mean_sd),
sd = Normal(sd, sd_sd)
)
params_mean <- vapply(temp_dist[[1]]$parameters, mean, numeric(1))
params_sd <- vapply(temp_dist[[1]]$parameters, sd_dist, numeric(1))
params_mean <- vapply(get_parameters(temp_dist), mean, numeric(1))
params_sd <- vapply(get_parameters(temp_dist), sd_dist, numeric(1))
} else if (distribution == "normal") {
params_mean <- c(mean = mean, sd = sd)
params_sd <- c(mean = mean_sd, sd = sd_sd)
Expand Down
113 changes: 99 additions & 14 deletions R/dist_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -582,10 +582,10 @@ print.dist_spec <- function(x, ...) {
} else if (x[[i]]$distribution == "fixed") {
## fixed
cat(indent_str, "- fixed value:\n", sep = "")
if (is.numeric(x[[i]]$parameters$value)) {
cat(indent_str, " ", x[[i]]$parameters$value, "\n", sep = "")
if (is.numeric(get_parameters(x, i)$value)) {
cat(indent_str, " ", get_parameters(x, i)$value, "\n", sep = "")
} else {
.print.dist_spec(x[[i]]$parameters$value, indent = indent + 4)
.print.dist_spec(get_parameters(x, i)$value, indent = indent + 4)
}
} else {
## parametric
Expand All @@ -595,18 +595,18 @@ print.dist_spec <- function(x, ...) {
}
cat(":\n")
## loop over natural parameters and print
for (param in names(x[[i]]$parameters)) {
for (param in names(get_parameters(x, i))) {
cat(
indent_str, " ", param, ":\n", sep = ""
)
if (is.numeric(x[[i]]$parameters[[param]])) {
if (is.numeric(get_parameters(x, i)[[param]])) {
cat(
indent_str, " ",
signif(x[[i]]$parameters[[param]], digits = 2), "\n",
signif(get_parameters(x, i)[[param]], digits = 2), "\n",
sep = ""
)
} else {
.print.dist_spec(x[[i]]$parameters[[param]], indent = indent + 4)
.print.dist_spec(get_parameters(x, i)[[param]], indent = indent + 4)
}
}
}
Expand Down Expand Up @@ -654,12 +654,12 @@ plot.dist_spec <- function(x, ...) {
for (i in seq_along(x)) {
if (x[[i]]$distribution == "nonparametric") {
# Fixed distribution
pmf <- x[[i]]$pmf
pmf <- get_pmf(x, i)
dist_name <- paste0("Nonparametric", " (ID: ", i, ")")
} else {
# Uncertain distribution
c_dist <- discretise(fix_dist(extract_single_dist(x, i)))
pmf <- c_dist[[1]]$pmf
pmf <- get_pmf(c_dist)
dist_name <- paste0(
ifelse(is.na(dist_sd[i]), "Uncertain ", ""),
x[[i]]$distribution, " (ID: ", i, ")"
Expand Down Expand Up @@ -951,14 +951,12 @@ extract_params <- function(params, distribution) {
#' @inheritParams extract_params
#' @importFrom purrr walk
#' @return A `dist_spec` of the given specification.
#' @keywords internal
#' @export
#' @examples
#' \dontrun{
#' new_dist_spec(
#' params = list(mean = 2, sd = 1, max = Inf),
#' distribution = "normal"
#' )
#' }
new_dist_spec <- function(params, distribution) {
if (distribution == "nonparametric") {
## nonparametric distribution
Expand All @@ -968,8 +966,12 @@ new_dist_spec <- function(params, distribution) {
)
} else {
## process min/max first
max <- params$max
params$max <- NULL
if (is.null(params$max)) {
max <- Inf
} else {
max <- params$max
params$max <- NULL
}
## extract parameters and convert all to dist_spec
params <- extract_params(params, distribution)
## fixed distribution
Expand Down Expand Up @@ -1100,3 +1102,86 @@ convert_to_natural <- function(params, distribution) {
}
return(params)
}

##' Perform checks for `<dist_spec>` `get_...` functions
##'
##' @param x A `<dist_spec>`.
##' @param id Integer; the id of the distribution to get parameters of (if x is
##' a composite distribution). If `x` is a single distribution this is ignored
##' and can be left as `NULL`.
##' @return The id to use.
##' @keywords internal
##' @author Sebastian Funk
get_dist_spec_id <- function(x, id) {
if (!is.null(id) && id > length(x)) {
stop(
"`id` can't be greater than the number of distributions (", length(x),
")."
)
}
if (length(x) > 1) {
if (is.null(id)) {
stop("`id` must be specified when `x` is a composite distribution.")
}
} else {
id <- 1
}
return(id)
}

##' Get parameters of a parametric distribution
##'
##' @inheritParams get_dist_spec_id
##' @description `r lifecycle::badge("experimental")`
##' @return A list of parameters of the distribution.
##' @export
##' @examples
##' dist <- Gamma(shape = 3, rate = 2)
##' get_parameters(dist)
get_parameters <- function(x, id = NULL) {
if (!is(x, "dist_spec")) {
stop("Can only get parameters of a <dist_spec>.")
}
id <- get_dist_spec_id(x, id)
if (x[[id]]$distribution == "nonparametric") {
stop("Cannot get parameters of a nonparametric distribution.")
}
return(x[[id]]$parameters)
}

##' Get the probability mass function of a nonparametric distribution
##'
##' @inheritParams get_dist_spec_id
##' @description `r lifecycle::badge("experimental")`
##' @return The pmf of the distribution
##' @export
##' @examples
##' dist <- discretise(Gamma(shape = 3, rate = 2, max = 10))
##' get_pmf(dist)
get_pmf <- function(x, id = NULL) {
if (!is(x, "dist_spec")) {
stop("Can only get pmf of a <dist_spec>.")
}
id <- get_dist_spec_id(x, id)
if (x[[id]]$distribution != "nonparametric") {
stop("Cannot get pmf of a parametric distribution.")
}
return(x[[id]]$pmf)
}

##' Get the distribution of a [dist_spec()]
##'
##' @inheritParams get_dist_spec_id
##' @description `r lifecycle::badge("experimental")`
##' @return A character string naming the distribution (or "nonparametric")
##' @export
##' @examples
##' dist <- Gamma(shape = 3, rate = 2, max = 10)
##' get_distribution(dist)
get_distribution <- function(x, id = NULL) {
if (!is(x, "dist_spec")) {
stop("Can only get distribution of a <dist_spec>.")
}
id <- get_dist_spec_id(x, id)
return(x[[id]]$distribution)
}
11 changes: 7 additions & 4 deletions R/estimate_truncation.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
#' # illustrative purposes only.
#' out <- epinow(
#' example_truncated[[5]],
#' truncation = est$dist
#' truncation = trunc_opts(est$dist)
#' )
#' plot(out)
#' options(old_opts)
Expand Down Expand Up @@ -291,9 +291,12 @@ estimate_truncation <- function(data, max_truncation, trunc_max = 10,
parameters <- purrr::map(seq_along(params_mean), function(id) {
Normal(params_mean[id], params_sd[id])
})
names(parameters) <- natural_params(truncation[[1]]$distribution)
out$dist <- truncation
out$dist[[1]]$parameters <- parameters
names(parameters) <- natural_params(get_distribution(truncation))
parameters$max <- max(truncation)
out$dist <- new_dist_spec(
params = parameters,
distribution = get_distribution(truncation)
)

# summarise reconstructed observations
recon_obs <- extract_stan_param(fit, "recon_obs",
Expand Down
6 changes: 5 additions & 1 deletion _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,11 @@ reference:
- apply_tolerance
- collapse
- discretise
- contains("dist")
- contains("_dist")
- contains("dist_")
- get_parameters
- get_pmf
- get_distribution
- title: Simulate
desc: Functions to help with simulating data or mapping to reported cases
contents:
Expand Down
10 changes: 5 additions & 5 deletions data-raw/truncated.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ library("EpiNow2")
#' @keywords internal
apply_truncation <- function(index, data, dist) {
set.seed(index)
if (dist[[1]]$distribution == 0) {
if (get_distribution(dist) == "lognormal") {
dfunc <- dlnorm
} else {
dfunc <- dgamma
Expand All @@ -20,12 +20,12 @@ apply_truncation <- function(index, data, dist) {
dfunc(
seq_len(max(dist) + 1),
rnorm(1,
dist[[1]]$parameters$meanlog[[1]]$parameters$mean,
dist[[1]]$parameters$meanlog[[1]]$parameters$sd
get_parameters(get_parameters(dist)$meanlog)$mean,
get_parameters(get_parameters(dist)$meanlog)$sd
),
rnorm(1,
dist[[1]]$parameters$sdlog[[1]]$parameters$mean,
dist[[1]]$parameters$sdlog[[1]]$parameters$sd
get_parameters(get_parameters(dist)$sdlog)$mean,
get_parameters(get_parameters(dist)$sdlog)$sd
)
)
)
Expand Down
5 changes: 1 addition & 4 deletions man/create_complete_cases.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/estimate_truncation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/get_dist_spec_id.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/get_distribution.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions man/get_parameters.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit c3218d2

Please sign in to comment.